"""
Num params: 73489960
Run it as: mpiexec -n {num_processes} python3.6 -m flows_celeba.launchers.celeba64_3bit_official from the flows master directory of the git repo. 
           num_processes=8 was used for this launcher on a 8-GPU (1080 Ti) machine with 40 GB RAM. 
If you want to use python3.5, remove the f string in the logdir.
"""

import numpy as np
import tensorflow as tf
from tensorflow.distributions import Normal
from tqdm import tqdm
from flows_imagenet.logistic import mixlogistic_logpdf, mixlogistic_logcdf, mixlogistic_invcdf
from flows_celeba import flow_training_celeba

DEFAULT_FLOATX = tf.float32
STORAGE_FLOATX = tf.float32

def to_default_floatx(x):
    return tf.cast(x, DEFAULT_FLOATX)

def at_least_float32(x):
    assert x.dtype in [tf.float16, tf.float32, tf.float64]
    if x.dtype == tf.float16:
        return tf.cast(x, tf.float32)
    return x

def get_var(var_name, *, ema, initializer, trainable=True, **kwargs):
    """forced storage dtype"""
    assert 'dtype' not in kwargs
    if isinstance(initializer, np.ndarray):
        initializer = initializer.astype(STORAGE_FLOATX.as_numpy_dtype)
    v = tf.get_variable(var_name, dtype=STORAGE_FLOATX, initializer=initializer, trainable=trainable, **kwargs)
    if ema is not None:
        assert isinstance(ema, tf.train.ExponentialMovingAverage)
        v = ema.average(v)
    return v

def _norm(x, *, axis, g, b, e=1e-5):
    assert x.shape.ndims == g.shape.ndims == b.shape.ndims
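    # standardize along `axis` (zero mean, unit variance), then apply elementwise scale g and shift b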
    u = tf.reduce_mean(x, axis=axis, keepdims=True)
    s = tf.reduce_mean(tf.squared_difference(x, u), axis=axis, keepdims=True)
    x = (x - u) * tf.rsqrt(s + e)
    return x * g + b

def norm(x, *, name, ema):
    """Layer norm over last axis"""
    with tf.variable_scope(name):
        dim = int(x.shape[-1])
        _g = get_var('g', ema=ema, shape=[dim], initializer=tf.constant_initializer(1))
        _b = get_var('b', ema=ema, shape=[dim], initializer=tf.constant_initializer(0))
        g, b = map(to_default_floatx, [_g, _b])
        bcast_shape = [1] * (x.shape.ndims - 1) + [dim]
        return _norm(x, g=tf.reshape(g, bcast_shape), b=tf.reshape(b, bcast_shape), axis=-1)

def int_shape(x):
    return list(map(int, x.shape.as_list()))

def sumflat(x):
    return tf.reduce_sum(tf.reshape(x, [x.shape[0], -1]), axis=1)

def inverse_sigmoid(x):
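    # logit function: the inverse of tf.sigmoid, equal to log(x / (1 - x))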
    return -tf.log(tf.reciprocal(x) - 1.)

def init_normalization(x, *, name, init_scale=1., init, ema):
    with tf.variable_scope(name):
        g = get_var('g', shape=x.shape[1:], initializer=tf.constant_initializer(1.), ema=ema)
        b = get_var('b', shape=x.shape[1:], initializer=tf.constant_initializer(0.), ema=ema)
        if init:
            # data-dependent initialization: set g, b so the init batch comes out normalized
            m_init, v_init = tf.nn.moments(x, [0])
            scale_init = init_scale * tf.rsqrt(v_init + 1e-8)
            assert m_init.shape == v_init.shape == scale_init.shape == g.shape == b.shape
            with tf.control_dependencies([
                g.assign(scale_init),
                b.assign(-m_init * scale_init)
            ]):
                g, b = tf.identity_n([g, b])
        return g, b

def dense(x, *, name, num_units, init_scale=1., init, ema):
    with tf.variable_scope(name):
        _, in_dim = x.shape
        W = get_var('W', shape=[in_dim, num_units], initializer=tf.random_normal_initializer(0, 0.05), ema=ema)
        b = get_var('b', shape=[num_units], initializer=tf.constant_initializer(0.), ema=ema)

        if init:
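            # data-dependent init (weight-norm style): rescale W and set b so the layer's output on
            # the initialization batch has approximately zero mean and standard deviation init_scale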
            y = tf.matmul(x, W)
            m_init, v_init = tf.nn.moments(y, [0])
            scale_init = init_scale * tf.rsqrt(v_init + 1e-8)
            with tf.control_dependencies([
                W.assign(W * scale_init[None, :]),
                b.assign(-m_init * scale_init),
            ]):
                x = tf.identity(x)

        return tf.nn.bias_add(tf.matmul(x, W), b)

def conv2d(x, *, name, num_units, filter_size=(3, 3), stride=(1, 1), pad='SAME', init_scale=1., init, ema):
    with tf.variable_scope(name):
        assert x.shape.ndims == 4
        W = get_var('W', shape=[*filter_size, int(x.shape[-1]), num_units],
                    initializer=tf.random_normal_initializer(0, 0.05), ema=ema)
        b = get_var('b', shape=[num_units], initializer=tf.constant_initializer(0.), ema=ema)

        if init:
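            # data-dependent init as in dense(), with output moments taken over the (N, H, W) axes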
            y = tf.nn.conv2d(x, W, [1, *stride, 1], pad)
            m_init, v_init = tf.nn.moments(y, [0, 1, 2])
            scale_init = init_scale * tf.rsqrt(v_init + 1e-8)
            with tf.control_dependencies([
                W.assign(W * scale_init[None, None, None, :]),
                b.assign(-m_init * scale_init),
            ]):
                x = tf.identity(x)

        return tf.nn.bias_add(tf.nn.conv2d(x, W, [1, *stride, 1], pad), b)

def nin(x, *, num_units, **kwargs):
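    # 1x1 "network in network" layer: a dense layer applied independently at every spatial position.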
    assert 'num_units' not in kwargs
    s = x.shape.as_list()
    x = tf.reshape(x, [np.prod(s[:-1]), s[-1]])
    x = dense(x, num_units=num_units, **kwargs)
    return tf.reshape(x, s[:-1] + [num_units])

def matmul_last_axis(x, w):
    _, out_dim = w.shape
    s = x.shape.as_list()
    x = tf.reshape(x, [np.prod(s[:-1]), s[-1]])
    x = tf.matmul(x, w)
    return tf.reshape(x, s[:-1] + [out_dim])

def concat_elu(x, *, axis=-1):
    return tf.nn.elu(tf.concat([x, -x], axis=axis))

def gate(x, *, axis):
    a, b = tf.split(x, 2, axis=axis)
    return a * tf.sigmoid(b)

def gated_resnet(x, *, name, a, nonlinearity=concat_elu, conv=conv2d, use_nin, init, ema, dropout_p):
    with tf.variable_scope(name):
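        # Gated residual block: nonlinearity -> conv (+ optional projection of auxiliary input `a`)
        # -> nonlinearity -> dropout -> conv/nin to 2*num_filters -> sigmoid gate -> residual add.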
        num_filters = int(x.shape[-1])

        c1 = conv(nonlinearity(x), name='c1', num_units=num_filters, init=init, ema=ema)
        if a is not None:  # add short-cut connection if auxiliary input 'a' is given
            c1 += nin(nonlinearity(a), name='a_proj', num_units=num_filters, init=init, ema=ema)
        c1 = nonlinearity(c1)
        if dropout_p > 0:
            c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)

        c2 = (nin if use_nin else conv)(c1, name='c2', num_units=num_filters * 2, init_scale=0.1, init=init, ema=ema)
        return x + gate(c2, axis=3)

def attn(x, *, name, pos_emb, heads, init, ema, dropout_p):
    with tf.variable_scope(name):
        bs, height, width, ch = x.shape.as_list()
        assert pos_emb.shape == [height, width, ch]
        assert ch % heads == 0
        timesteps = height * width
        dim = ch // heads
        # Position embeddings
        c = x + pos_emb[None, :, :, :]
        # b, h, t, d == batch, num heads, num timesteps, per-head dim (C // heads)
        c = nin(c, name='proj1', num_units=3 * ch, init=init, ema=ema)
        assert c.shape == [bs, height, width, 3 * ch]
        # Split into heads / Q / K / V
        c = tf.reshape(c, [bs, timesteps, 3, heads, dim])  # b, t, 3, h, d
        c = tf.transpose(c, [2, 0, 3, 1, 4])  # 3, b, h, t, d
        q_bhtd, k_bhtd, v_bhtd = tf.unstack(c, axis=0)
        assert q_bhtd.shape == k_bhtd.shape == v_bhtd.shape == [bs, heads, timesteps, dim]
        # Attention
        w_bhtt = tf.matmul(q_bhtd, k_bhtd, transpose_b=True) / np.sqrt(float(dim))
        w_bhtt = tf.cast(tf.nn.softmax(at_least_float32(w_bhtt)), dtype=x.dtype)
        assert w_bhtt.shape == [bs, heads, timesteps, timesteps]
        a_bhtd = tf.matmul(w_bhtt, v_bhtd)
        # Merge heads
        a_bthd = tf.transpose(a_bhtd, [0, 2, 1, 3])
        assert a_bthd.shape == [bs, timesteps, heads, dim]
        a_btc = tf.reshape(a_bthd, [bs, timesteps, ch])
        # Project
        c1 = tf.reshape(a_btc, [bs, height, width, ch])
        if dropout_p > 0:
            c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
        c2 = nin(c1, name='proj2', num_units=ch * 2, init_scale=0.1, init=init, ema=ema)
        return x + gate(c2, axis=3)

class Flow:
    def forward(self, x, **kwargs):
        raise NotImplementedError
    def backward(self, y, **kwargs):
        raise NotImplementedError

class Inverse(Flow):
    def __init__(self, base_flow):
        self.base_flow = base_flow

    def forward(self, x, **kwargs):
        return self.base_flow.inverse(x, **kwargs)
    
    def inverse(self, y, **kwargs):
        return self.base_flow.forward(y, **kwargs)

class Compose(Flow):
    def __init__(self, flows):
        self.flows = flows

    def _maybe_tqdm(self, iterable, desc, verbose):
        return tqdm(iterable, desc=desc) if verbose else iterable

    def forward(self, x, **kwargs):
        bs = int((x[0] if isinstance(x, tuple) else x).shape[0])
        logd_terms = []
        for i, f in enumerate(self._maybe_tqdm(self.flows, desc='forward {}'.format(kwargs),
                                               verbose=kwargs.get('verbose'))):
            assert isinstance(f, Flow)
            x, l = f.forward(x, **kwargs)
            if l is not None:
                assert l.shape == [bs]
                logd_terms.append(l)
        return x, tf.add_n(logd_terms) if logd_terms else tf.constant(0.)

    def inverse(self, y, **kwargs):
        bs = int((y[0] if isinstance(y, tuple) else y).shape[0])
        logd_terms = []
        for i, f in enumerate(
                self._maybe_tqdm(self.flows[::-1], desc='inverse {}'.format(kwargs), verbose=kwargs.get('verbose'))):
            assert isinstance(f, Flow)
            y, l = f.inverse(y, **kwargs)
            if l is not None:
                assert l.shape == [bs]
                logd_terms.append(l)
        return y, tf.add_n(logd_terms) if logd_terms else tf.constant(0.)

class ImgProc(Flow):
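    """Maps dequantized 3-bit pixel values in [0, 8] to logit space: affine squeeze into [.05, .95], then inverse sigmoid."""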
    def forward(self, x, **kwargs):
        x = x * (.9 / 8) + .05  # [0, 8] -> [.05, .95]
        x = -tf.log(1. / x - 1.)  # inverse sigmoid
        logd = np.log(.9 / 8) + tf.nn.softplus(x) + tf.nn.softplus(-x)
        logd = tf.reduce_sum(tf.reshape(logd, [int_shape(logd)[0], -1]), 1)
        return x, logd

    def inverse(self, y, **kwargs):
        y = tf.sigmoid(y)
        logd = tf.log(y) + tf.log(1. - y)
        y = (y - .05) / (.9 / 8)  # [.05, .95] -> [0, 8]
        logd -= np.log(.9 / 8)
        logd = tf.reduce_sum(tf.reshape(logd, [int_shape(logd)[0], -1]), 1)
        return y, logd

class TupleFlip(Flow):
    def forward(self, x, **kwargs):
        assert isinstance(x, tuple)
        a, b = x
        return (b, a), None

    def inverse(self, y, **kwargs):
        assert isinstance(y, tuple)
        a, b = y
        return (b, a), None

class SpaceToDepth(Flow):
    def __init__(self, block_size=2):
        self.block_size = block_size

    def forward(self, x, **kwargs):
        return tf.space_to_depth(x, self.block_size), None

    def inverse(self, y, **kwargs):
        return tf.depth_to_space(y, self.block_size), None

class CheckerboardSplit(Flow):
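    """Splits an NHWC tensor along the width axis into its even and odd columns (a, b)."""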
    def forward(self, x, **kwargs):
        assert isinstance(x, tf.Tensor)
        B, H, W, C = x.shape
        x = tf.reshape(x, [B, H, W // 2, 2, C])
        a, b = tf.unstack(x, axis=3)
        assert a.shape == b.shape == [B, H, W // 2, C]
        return (a, b), None

    def inverse(self, y, **kwargs):
        assert isinstance(y, tuple)
        a, b = y
        assert a.shape == b.shape
        B, H, W_half, C = a.shape
        x = tf.stack([a, b], axis=3)
        assert x.shape == [B, H, W_half, 2, C]
        return tf.reshape(x, [B, H, W_half * 2, C]), None

class ChannelSplit(Flow):
    def forward(self, x, **kwargs):
        assert isinstance(x, tf.Tensor)
        assert len(x.shape) == 4 and x.shape[3] % 2 == 0
        return tuple(tf.split(x, 2, axis=3)), None

    def inverse(self, y, **kwargs):
        assert isinstance(y, tuple)
        a, b = y
        return tf.concat([a, b], axis=3), None

class Sigmoid(Flow):
    def forward(self, x, **kwargs):
        y = tf.sigmoid(x)
        logd = -tf.nn.softplus(x) - tf.nn.softplus(-x)
        return y, sumflat(logd)
    def inverse(self, y, **kwargs):
        x = inverse_sigmoid(y)
        logd = -tf.log(y) - tf.log(1. - y)
        return x, sumflat(logd)

class Norm(Flow):
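    """
    Activation normalization: elementwise scale g and shift b over all non-batch dimensions, with
    data-dependent initialization; the per-example log-determinant is sum(log g).
    """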
    def __init__(self, init_scale=1.):
        def f(input_, forward, init, ema):
            assert not isinstance(input_, list)
            if isinstance(input_, tuple):
                is_tuple = True
            else:
                assert isinstance(input_, tf.Tensor)
                input_ = [input_]
                is_tuple = False

            bs = int(input_[0].shape[0])
            g_and_b = []
            for (i, x) in enumerate(input_):
                g, b = init_normalization(x, name='norm{}'.format(i), init_scale=init_scale, init=init, ema=ema)
                g = tf.maximum(g, 1e-10)
                assert x.shape[0] == bs and g.shape == b.shape == x.shape[1:]
                g_and_b.append((g, b))

            logd = tf.fill([bs], tf.add_n([tf.reduce_sum(tf.log(g)) for (g, _) in g_and_b]))
            if forward:
                out = [x * g[None] + b[None] for (x, (g, b)) in zip(input_, g_and_b)]
            else:
                out = [(x - b[None]) / g[None] for (x, (g, b)) in zip(input_, g_and_b)]
                logd = -logd

            if not is_tuple:
                assert len(out) == 1
                return out[0], logd
            return tuple(out), logd

        self.template = tf.make_template(self.__class__.__name__, f)

    def forward(self, x, init=False, ema=None, **kwargs):
        return self.template(x, forward=True, init=init, ema=ema)

    def inverse(self, y, init=False, ema=None, **kwargs):
        return self.template(y, forward=False, init=init, ema=ema)            

class MixLogisticCoupling(Flow):
    """
    CDF of mixture of logistics, followed by affine
    """

    def __init__(self, filters, blocks, use_nin, components, attn_heads, use_ln,
                 with_affine=True, use_final_nin=False, init_scale=0.1, nonlinearity=concat_elu):
        self.components = components
        self.with_affine = with_affine
        self.scale_flow = Inverse(Sigmoid())

        def f(x, init, ema, dropout_p, verbose, context):
            # if verbose and context is not None:
            #     print('got context')
            if init and verbose:
                # debug stuff
                with tf.variable_scope('debug'):
                    xmean, xvar = tf.nn.moments(x, axes=list(range(len(x.get_shape()))))
                    x = tf.Print(
                        x,
                        [
                            tf.shape(x), xmean, tf.sqrt(xvar), tf.reduce_min(x), tf.reduce_max(x),
                            tf.reduce_any(tf.is_nan(x)), tf.reduce_any(tf.is_inf(x))
                        ],
                        message='{} (shape/mean/std/min/max/nan/inf) '.format(self.template.variable_scope.name),
                        summarize=10,
                    )
            B, H, W, C = x.shape.as_list()

            pos_emb = to_default_floatx(get_var(
                'pos_emb', ema=ema, shape=[H, W, filters], initializer=tf.random_normal_initializer(stddev=0.01),
            ))
            x = conv2d(x, name='c1', num_units=filters, init=init, ema=ema)
            for i_block in range(blocks):
                with tf.variable_scope('block{}'.format(i_block)):
                    x = gated_resnet(
                        x, name='conv', a=context, use_nin=use_nin, init=init, ema=ema, dropout_p=dropout_p
                    )
                    if use_ln:
                        x = norm(x, name='ln1', ema=ema)
            x = nonlinearity(x)
            x = (nin if use_final_nin else conv2d)(
                x, name='c2', num_units=C * (2 + 3 * components), init_scale=init_scale, init=init, ema=ema
            )
            assert x.shape == [B, H, W, C * (2 + 3 * components)]
            x = tf.reshape(x, [B, H, W, C, 2 + 3 * components])

            x = at_least_float32(x)  # do mix-logistics in tf.float32

            s, t = tf.tanh(x[:, :, :, :, 0]), x[:, :, :, :, 1]
            ml_logits, ml_means, ml_logscales = tf.split(x[:, :, :, :, 2:], 3, axis=4)
            ml_logscales = tf.maximum(ml_logscales, -7.)

            assert s.shape == t.shape == [B, H, W, C]
            assert ml_logits.shape == ml_means.shape == ml_logscales.shape == [B, H, W, C, components]
            return s, t, ml_logits, ml_means, ml_logscales

        self.template = tf.make_template(self.__class__.__name__, f)

    def forward(self, x, init=False, ema=None, dropout_p=0., verbose=True, context=None, **kwargs):
        assert isinstance(x, tuple)
        cf, ef = x
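        # cf is the passive (conditioning) half, ef the active half. Elementwise transform:
        #   out = exp(s) * inverse_sigmoid( MixLogisticCDF(ef) ) + t
        # with per-element log|det| = log MixLogisticPDF(ef) + the inverse-sigmoid log-det + s.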
        float_ef = at_least_float32(ef)
        s, t, ml_logits, ml_means, ml_logscales = self.template(
            cf, init=init, ema=ema, dropout_p=dropout_p, verbose=verbose, context=context
        )

        out = tf.exp(
            mixlogistic_logcdf(x=float_ef, prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)
        )
        out, scale_logd = self.scale_flow.forward(out)
        if self.with_affine:
            assert out.shape == s.shape == t.shape
            out = tf.exp(s) * out + t

        logd = mixlogistic_logpdf(x=float_ef, prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)
        if self.with_affine:
            assert s.shape == logd.shape
            logd += s
        logd = tf.reduce_sum(tf.layers.flatten(logd), axis=1)
        assert scale_logd.shape == logd.shape
        logd += scale_logd

        out, logd = map(to_default_floatx, [out, logd])
        assert out.shape == ef.shape == cf.shape and out.dtype == ef.dtype == logd.dtype == cf.dtype
        return (cf, out), logd

    def inverse(self, y, init=False, ema=None, dropout_p=0., verbose=True, context=None, **kwargs):
        assert isinstance(y, tuple)
        cf, ef = y
        float_ef = at_least_float32(ef)
        s, t, ml_logits, ml_means, ml_logscales = self.template(
            cf, init=init, ema=ema, dropout_p=dropout_p, verbose=verbose, context=context
        )

        out = float_ef
        if self.with_affine:
            out = tf.exp(-s) * (out - t)
        out, invscale_logd = self.scale_flow.inverse(out)
        out = tf.clip_by_value(out, 1e-5, 1. - 1e-5)
        out = mixlogistic_invcdf(y=out, prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)

        logd = mixlogistic_logpdf(x=out, prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)
        if self.with_affine:
            assert s.shape == logd.shape
            logd += s
        logd = -tf.reduce_sum(tf.layers.flatten(logd), axis=1)
        assert invscale_logd.shape == logd.shape
        logd += invscale_logd

        out, logd = map(to_default_floatx, [out, logd])
        assert out.shape == ef.shape == cf.shape and out.dtype == ef.dtype == logd.dtype == cf.dtype
        return (cf, out), logd

class MixLogisticAttnCoupling(Flow):
    """
    CDF of mixture of logistics, followed by affine
    """

    def __init__(self, filters, blocks, use_nin, components, attn_heads, use_ln,
                 with_affine=True, use_final_nin=False, init_scale=0.1, nonlinearity=concat_elu):
        self.components = components
        self.with_affine = with_affine
        self.scale_flow = Inverse(Sigmoid())

        def f(x, init, ema, dropout_p, verbose, context):
            if init and verbose:
                with tf.variable_scope('debug'):
                    xmean, xvar = tf.nn.moments(x, axes=list(range(len(x.get_shape()))))
                    x = tf.Print(
                        x,
                        [
                            tf.shape(x), xmean, tf.sqrt(xvar), tf.reduce_min(x), tf.reduce_max(x),
                            tf.reduce_any(tf.is_nan(x)), tf.reduce_any(tf.is_inf(x))
                        ],
                        message='{} (shape/mean/std/min/max/nan/inf) '.format(self.template.variable_scope.name),
                        summarize=10,
                    )
            B, H, W, C = x.shape.as_list()

            pos_emb = to_default_floatx(get_var(
                'pos_emb', ema=ema, shape=[H, W, filters], initializer=tf.random_normal_initializer(stddev=0.01),
            ))
            x = conv2d(x, name='c1', num_units=filters, init=init, ema=ema)
            for i_block in range(blocks):
                with tf.variable_scope('block{}'.format(i_block)):
                    x = gated_resnet(
                        x, name='conv', a=context, use_nin=use_nin, init=init, ema=ema, dropout_p=dropout_p
                    )
                    if use_ln:
                        x = norm(x, name='ln1', ema=ema)
                    x = attn(
                        x, name='attn', pos_emb=pos_emb, heads=attn_heads, init=init, ema=ema, dropout_p=dropout_p
                    )
                    if use_ln:
                        x = norm(x, name='ln2', ema=ema)
                    assert x.shape == [B, H, W, filters]
            x = nonlinearity(x)
            x = (nin if use_final_nin else conv2d)(
                x, name='c2', num_units=C * (2 + 3 * components), init_scale=init_scale, init=init, ema=ema
            )
            assert x.shape == [B, H, W, C * (2 + 3 * components)]
            x = tf.reshape(x, [B, H, W, C, 2 + 3 * components])

            x = at_least_float32(x)  # do mix-logistics stuff in float32

            s, t = tf.tanh(x[:, :, :, :, 0]), x[:, :, :, :, 1]
            ml_logits, ml_means, ml_logscales = tf.split(x[:, :, :, :, 2:], 3, axis=4)
            ml_logscales = tf.maximum(ml_logscales, -7.)

            assert s.shape == t.shape == [B, H, W, C]
            assert ml_logits.shape == ml_means.shape == ml_logscales.shape == [B, H, W, C, components]
            return s, t, ml_logits, ml_means, ml_logscales

        self.template = tf.make_template(self.__class__.__name__, f)

    def forward(self, x, init=False, ema=None, dropout_p=0., verbose=True, context=None, **kwargs):
        assert isinstance(x, tuple)
        cf, ef = x
        float_ef = at_least_float32(ef)
        s, t, ml_logits, ml_means, ml_logscales = self.template(
            cf, init=init, ema=ema, dropout_p=dropout_p, verbose=verbose, context=context
        )

        out = tf.exp(
            mixlogistic_logcdf(x=float_ef, prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)
        )
        out, scale_logd = self.scale_flow.forward(out)
        if self.with_affine:
            assert out.shape == s.shape == t.shape
            out = tf.exp(s) * out + t

        logd = mixlogistic_logpdf(x=float_ef, prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)
        if self.with_affine:
            assert s.shape == logd.shape
            logd += s
        logd = tf.reduce_sum(tf.layers.flatten(logd), axis=1)
        assert scale_logd.shape == logd.shape
        logd += scale_logd

        out, logd = map(to_default_floatx, [out, logd])
        assert out.shape == ef.shape == cf.shape and out.dtype == ef.dtype == logd.dtype == cf.dtype
        return (cf, out), logd

    def inverse(self, y, init=False, ema=None, dropout_p=0., verbose=True, context=None, **kwargs):
        assert isinstance(y, tuple)
        cf, ef = y
        float_ef = at_least_float32(ef)
        s, t, ml_logits, ml_means, ml_logscales = self.template(
            cf, init=init, ema=ema, dropout_p=dropout_p, verbose=verbose, context=context
        )

        out = float_ef
        if self.with_affine:
            out = tf.exp(-s) * (out - t)
        out, invscale_logd = self.scale_flow.inverse(out)
        out = tf.clip_by_value(out, 1e-5, 1. - 1e-5)
        out = mixlogistic_invcdf(y=out, prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)

        logd = mixlogistic_logpdf(x=out, prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)
        if self.with_affine:
            assert s.shape == logd.shape
            logd += s
        logd = -tf.reduce_sum(tf.layers.flatten(logd), axis=1)
        assert invscale_logd.shape == logd.shape
        logd += invscale_logd

        out, logd = map(to_default_floatx, [out, logd])
        assert out.shape == ef.shape == cf.shape and out.dtype == ef.dtype == logd.dtype == cf.dtype
        return (cf, out), logd

def gaussian_sample_logp(shape, dtype):
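    # Sample eps ~ N(0, I) and return it together with its log-density, summed over non-batch dims.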
    eps = tf.random_normal(shape)
    logp = Normal(0., 1.).log_prob(eps)
    assert logp.shape == eps.shape
    logp = tf.reduce_sum(tf.layers.flatten(logp), axis=1)
    return tf.cast(eps, dtype=dtype), tf.cast(logp, dtype=dtype)

class Dequantizer(Flow):
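    """
    Variational dequantization: noise u = Sigmoid(f(eps)) is sampled from a conditional flow given x
    (the conditioning context comes from a small gated-ResNet processor of x), and forward() returns
    (x + u, -log q(u | x)), i.e. the dequantization term of the variational lower bound.
    """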
    def __init__(self, dequant_flow):
        super().__init__()
        assert isinstance(dequant_flow, Flow)
        self.dequant_flow = dequant_flow

        def deep_processor(x, *, init, ema, dropout_p):
            (this, that), _ = CheckerboardSplit().forward(x)
            processed_context = conv2d(tf.concat([this, that], 3), name='proj', num_units=32, init=init, ema=ema)
            for i in range(5):
                processed_context = gated_resnet(
                    processed_context, name='c{}'.format(i),
                    a=None, dropout_p=dropout_p, ema=ema, init=init,
                    use_nin=False
                )
                processed_context = norm(processed_context, name='dqln{}'.format(i), ema=ema)
                
            return processed_context

        self.context_proc = tf.make_template("context_proc", deep_processor)

    def forward(self, x, init=False, ema=None, dropout_p=0., verbose=True, **kwargs):
        eps, eps_logli = gaussian_sample_logp(x.shape, dtype=DEFAULT_FLOATX)
        unbound_xd, logd = self.dequant_flow.forward(
            eps,
            context=self.context_proc(x / 8.0 - 0.5, init=init, ema=ema, dropout_p=dropout_p),
            init=init, ema=ema, dropout_p=dropout_p, verbose=verbose
        )
        xd, sigmoid_logd = Sigmoid().forward(unbound_xd)
        assert x.shape == xd.shape and logd.shape == sigmoid_logd.shape == eps_logli.shape
        return x + xd, logd + sigmoid_logd - eps_logli


def construct(*, filters, blocks, components, attn_heads, use_nin, use_ln):
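    # Multi-scale architecture: a shallow dequantization flow (four checkerboard couplings without
    # attention, 5 ResNet blocks each) plus the main flow, which alternates checkerboard and channel
    # couplings across three SpaceToDepth squeezes (each halves the spatial resolution and quadruples
    # the channel count), with a Norm layer before every coupling.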
    dequant_coupling_kwargs = dict(
        filters=filters, blocks=5, use_nin=use_nin, components=components, attn_heads=attn_heads, use_ln=use_ln
    )
    dequant_flow = Dequantizer(Compose([
        CheckerboardSplit(),
        Norm(), 
        MixLogisticCoupling(**dequant_coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticCoupling(**dequant_coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticCoupling(**dequant_coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticCoupling(**dequant_coupling_kwargs), TupleFlip(),
        Inverse(CheckerboardSplit()),
    ]))

    coupling_kwargs = dict(
        filters=filters, blocks=blocks, use_nin=use_nin, components=components, attn_heads=attn_heads, use_ln=use_ln
    )
    flow = Compose([
        ImgProc(),

        SpaceToDepth(),

        CheckerboardSplit(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Inverse(CheckerboardSplit()),

        SpaceToDepth(),

        ChannelSplit(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Inverse(ChannelSplit()),

        CheckerboardSplit(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Inverse(CheckerboardSplit()),

        SpaceToDepth(),
        
        ChannelSplit(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Inverse(ChannelSplit()),

        CheckerboardSplit(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Norm(), 
        MixLogisticAttnCoupling(**coupling_kwargs), TupleFlip(),
        Inverse(CheckerboardSplit()),
    
    ])
    return dequant_flow, flow

def main():
    global DEFAULT_FLOATX
    DEFAULT_FLOATX = tf.float32

    max_lr = 4e-5
    warmup_steps = 20000
    bs = 56
    # Total batch size across all MPI processes; set this to a smaller value if it doesn't fit on your GPUs.
    # bs must be divisible by the number of MPI processes, otherwise an assertion error is raised.

    curr_lr = max_lr  # tracked across calls during the decay phase of the schedule

    def lr_schedule(step, *, decay=0.9995):
        """Ramp up linearly to 4e-5 over 20K steps, hold until 50K steps, then decay geometrically to 1e-5 and hold."""
        nonlocal curr_lr
        if step < warmup_steps:
            return max_lr * step / warmup_steps
        elif step <= 2.5 * warmup_steps:
            curr_lr = max_lr
            return max_lr
        elif curr_lr > 1e-5:
            curr_lr *= decay
            return curr_lr
        return curr_lr

    dropout_p = 0.
    filters = 96
    blocks = 16
    components = 4  # logistic mixture components
    attn_heads = 4
    use_ln = True
 
    floatx_str = {tf.float32: 'fp32', tf.float16: 'fp16'}[DEFAULT_FLOATX]
    flow_training_celeba.train(
        flow_constructor=lambda: construct(
            filters=filters,
            components=components,
            attn_heads=attn_heads,
            blocks=blocks,
            use_nin=True,
            use_ln=use_ln
        ),
        logdir=f'~/logs/2018-11-12/celeba64_3bit_ELU_code_release_mix{components}_b{blocks}_f{filters}_h{attn_heads}_ln{int(use_ln)}_lr{max_lr}_bs{bs}_drop{dropout_p}_{floatx_str}',
        lr_schedule=lr_schedule,
        dropout_p=dropout_p,
        seed=0,
        init_bs=56, # set this to a smaller value if it can't fit on your GPU.
        dataset='celeba64_3bit',
        total_bs=bs,
        ema_decay=.999,
        steps_per_log=100,
        steps_per_dump=5000,
        steps_per_samples=100,
        max_grad_norm=1.,
        dtype=DEFAULT_FLOATX,
        scale_loss=1e-2 if DEFAULT_FLOATX == tf.float16 else None,
        n_epochs=20,  # good samples should start appearing by the end of 5 epochs.
        restore_checkpoint=None,  # path to a checkpoint in the format path_to_checkpoint/model (no .meta / .ckpt suffix)
        save_jpg=True,  # set to True/False depending on whether you want JPG versions of the low-bit samples saved while the model trains.
    )

if __name__ == '__main__':
    main()