"""MADE and ResMADE."""
import time

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Linear):
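    """Linear layer with a configurable binary connectivity mask on its weights."""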

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer('mask', torch.ones(out_features, in_features))

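        # Optionally set by callers to a precomputed (mask * weight) tensor so
        # that forward() can skip the elementwise multiply at inference time.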
        self.masked_weight = None

    def set_mask(self, mask):
        """Accepts a mask of shape [in_features, out_features]."""
        self.mask.data.copy_(torch.from_numpy(mask.astype(np.uint8).T))

    def forward(self, input):
        if self.masked_weight is None:
            return F.linear(input, self.mask * self.weight, self.bias)
        else:
            # ~17% speedup for Prog Sampling.
            return F.linear(input, self.masked_weight, self.bias)


class MaskedResidualBlock(nn.Module):
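    """Residual block of two MaskedLinear layers that share the same mask."""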

    def __init__(self, in_features, out_features, activation):
        assert in_features == out_features, [in_features, out_features]
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(MaskedLinear(in_features, out_features, bias=True))
        self.layers.append(MaskedLinear(in_features, out_features, bias=True))
        self.activation = activation

    def set_mask(self, mask):
        self.layers[0].set_mask(mask)
        self.layers[1].set_mask(mask)

    def forward(self, input):
        out = input
        out = self.activation(out)
        out = self.layers[0](out)
        out = self.activation(out)
        out = self.layers[1](out)
        return input + out


class MADE(nn.Module):

    def __init__(
            self,
            nin,
            hidden_sizes,
            nout,
            num_masks=1,
            natural_ordering=True,
            input_bins=None,
            activation=nn.ReLU,
            do_direct_io_connections=False,
            input_encoding=None,
            output_encoding='one_hot',
            embed_size=32,
            input_no_emb_if_leq=True,
            residual_connections=False,
            column_masking=False,
            seed=11123,
            fixed_ordering=None,
    ):
        """MADE.

        Args:
          nin: integer; number of input variables.  Each input variable
            represents a column.
          hidden_sizes: a list of integers; number of units in each hidden
            layer.
          nout: integer; number of outputs, the sum of all input variables'
            domain sizes.
          num_masks: number of orderings + connectivity masks to cycle through.
          natural_ordering: force natural ordering of dimensions, don't use
            random permutations.
          input_bins: classes each input var can take on, e.g., [5, 2] means
            input x1 has values in {0, ..., 4} and x2 in {0, 1}.  In other
            words, the domain sizes.
          activation: the activation to use.
          do_direct_io_connections: whether to add a connection from inputs to
            output layer.  Helpful for information flow.
          input_encoding: input encoding mode, see EncodeInput().
          output_encoding: output logits decoding mode, either 'embed' or
            'one_hot'.  See logits_for_col().
          embed_size: int, embedding dim.
          input_no_emb_if_leq: optimization, whether to skip embeddings for
            variables whose domain size is at most embed_size.  If so, those
            variables get no learnable embeddings and are instead encoded as
            one-hot vectors.
          residual_connections: use ResMADE?  Could lead to faster learning.
          column_masking: if True, turn on column masking during training time,
            which enables the wildcard skipping optimization during inference.
            Recommended to be set for any non-trivial datasets.
          seed: seed for generating random connectivity masks.
          fixed_ordering: variable ordering to use.  If specified, order[i]
            maps natural index i -> position in ordering.  E.g., if order[0] =
            2, variable 0 is placed at position 2.
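
        Example (illustrative sketch; the two-column schema is hypothetical):
          x = torch.LongTensor([[0, 3], [1, 4]])  # Domains of sizes 2 and 5.
          model = MADE(nin=2, hidden_sizes=[64, 64], nout=2 + 5,
                       input_bins=[2, 5])
          logits = model(x.float())
          nll = model.nll(logits, x)  # Per-example negative log-likelihoods.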
        """
        super().__init__()
        print('fixed_ordering', fixed_ordering, 'seed', seed,
              'natural_ordering', natural_ordering)
        self.nin = nin
        assert input_encoding in [None, 'one_hot', 'binary', 'embed']
        self.input_encoding = input_encoding
        assert output_encoding in ['one_hot', 'embed']
        self.embed_size = self.emb_dim = embed_size
        self.output_encoding = output_encoding
        self.activation = activation
        self.nout = nout
        self.hidden_sizes = hidden_sizes
        self.input_bins = input_bins
        self.input_no_emb_if_leq = input_no_emb_if_leq
        self.do_direct_io_connections = do_direct_io_connections
        self.column_masking = column_masking
        self.residual_connections = residual_connections

        self.fixed_ordering = fixed_ordering
        if fixed_ordering is not None:
            assert num_masks == 1
            print('** Fixed ordering {} supplied, ignoring natural_ordering'.
                  format(fixed_ordering))

        assert self.input_bins
        encoded_bins = list(
            map(self._get_output_encoded_dist_size, self.input_bins))
        self.input_bins_encoded = list(
            map(self._get_input_encoded_dist_size, self.input_bins))
        self.input_bins_encoded_cumsum = np.cumsum(
            list(map(self._get_input_encoded_dist_size, self.input_bins)))
        print('encoded_bins (output)', encoded_bins)
        print('encoded_bins (input)', self.input_bins_encoded)

        hs = [nin] + hidden_sizes + [sum(encoded_bins)]
        self.net = []
        for h0, h1 in zip(hs, hs[1:]):
            if residual_connections:
                if h0 == h1:
                    self.net.extend([
                        MaskedResidualBlock(
                            h0, h1, activation=activation(inplace=False))
                    ])
                else:
                    self.net.extend([
                        MaskedLinear(h0, h1),
                    ])
            else:
                self.net.extend([
                    MaskedLinear(h0, h1),
                    activation(inplace=True),
                ])
        if not residual_connections:
            self.net.pop()
        self.net = nn.Sequential(*self.net)

        if self.input_encoding is not None:
            # Input layer should be changed.
            assert self.input_bins is not None
            input_size = 0
            for i, dist_size in enumerate(self.input_bins):
                input_size += self._get_input_encoded_dist_size(dist_size)
            new_layer0 = MaskedLinear(input_size, self.net[0].out_features)
            self.net[0] = new_layer0

        if self.output_encoding == 'embed':
            assert self.input_encoding == 'embed'

        if self.input_encoding == 'embed':
            self.embeddings = nn.ModuleList()
            for i, dist_size in enumerate(self.input_bins):
                if dist_size <= self.embed_size and self.input_no_emb_if_leq:
                    embed = None
                else:
                    embed = nn.Embedding(dist_size, self.embed_size)
                self.embeddings.append(embed)

        # Learnable [MASK] representation.
        if self.column_masking:
            self.unk_embeddings = nn.ParameterList()
            for i, dist_size in enumerate(self.input_bins):
                self.unk_embeddings.append(
                    nn.Parameter(torch.zeros(1, self.input_bins_encoded[i])))

        self.natural_ordering = natural_ordering
        self.num_masks = num_masks
        self.seed = seed if seed is not None else 11123
        self.init_seed = self.seed

        self.direct_io_layer = None
        self.logit_indices = np.cumsum(encoded_bins)
        self.m = {}

        self.update_masks()
        self.orderings = [self.m[-1]]

        # Optimization: cache some values needed in EncodeInput().
        self.bin_as_onehot_shifts = None

    def _build_or_update_direct_io(self):
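        """Builds or refreshes the mask of the direct input-to-output layer.

        Each (encoded) input variable is connected to the output units of all
        variables that appear after it in the current ordering.
        """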
        assert self.nout > self.nin and self.input_bins is not None
        direct_nin = self.net[0].in_features
        direct_nout = self.net[-1].out_features
        if self.direct_io_layer is None:
            self.direct_io_layer = MaskedLinear(direct_nin, direct_nout)
        mask = np.zeros((direct_nout, direct_nin), dtype=np.uint8)

        if self.natural_ordering:
            curr = 0
            for i in range(self.nin):
                dist_size = self._get_input_encoded_dist_size(
                    self.input_bins[i])
                # Input i connects to groups > i.
                mask[self.logit_indices[i]:, curr:curr + dist_size] = 1
                curr += dist_size
        else:
            # Inverse: ord_idx -> natural idx.
            inv_ordering = [None] * self.nin
            for natural_idx in range(self.nin):
                inv_ordering[self.m[-1][natural_idx]] = natural_idx

            for ord_i in range(self.nin):
                nat_i = inv_ordering[ord_i]
                # x_(nat_i) in the input occupies range [inp_l, inp_r).
                inp_l = 0 if nat_i == 0 else self.input_bins_encoded_cumsum[
                    nat_i - 1]
                inp_r = self.input_bins_encoded_cumsum[nat_i]
                assert inp_l < inp_r

                for ord_j in range(ord_i + 1, self.nin):
                    nat_j = inv_ordering[ord_j]
                    # Output x_(nat_j) should connect to input x_(nat_i); it
                    # occupies range [out_l, out_r) in the output.
                    out_l = 0 if nat_j == 0 else self.logit_indices[nat_j - 1]
                    out_r = self.logit_indices[nat_j]
                    assert out_l < out_r
                    mask[out_l:out_r, inp_l:inp_r] = 1
        mask = mask.T
        self.direct_io_layer.set_mask(mask)

    def _get_input_encoded_dist_size(self, dist_size):
        if self.input_encoding == 'embed':
            if self.input_no_emb_if_leq:
                dist_size = min(dist_size, self.embed_size)
            else:
                dist_size = self.embed_size
        elif self.input_encoding == 'one_hot':
            pass
        elif self.input_encoding == 'binary':
            dist_size = max(1, int(np.ceil(np.log2(dist_size))))
        elif self.input_encoding is None:
            return 1
        else:
            assert False, self.input_encoding
        return dist_size

    def _get_output_encoded_dist_size(self, dist_size):
        if self.output_encoding == 'embed':
            if self.input_no_emb_if_leq:
                dist_size = min(dist_size, self.embed_size)
            else:
                dist_size = self.embed_size
        elif self.output_encoding == 'one_hot':
            pass
        elif self.output_encoding == 'binary':
            dist_size = max(1, int(np.ceil(np.log2(dist_size))))
        return dist_size

    def update_masks(self, invoke_order=None):
        """Update m() for all layers and change masks correspondingly.

        No-op on repeated calls when "self.num_masks" is 1.
        """
        if self.m and self.num_masks == 1:
            return
        L = len(self.hidden_sizes)

        ### Precedence of several params determining ordering:
        #
        # invoke_order
        # orderings
        # fixed_ordering
        # natural_ordering
        #
        # from high precedence to low.

        if invoke_order is not None:
            found = False
            for i in range(len(self.orderings)):
                if np.array_equal(self.orderings[i], invoke_order):
                    found = True
                    break
            assert found, 'specified={}, avail={}'.format(
                invoke_order, self.orderings)
            # orderings = [ o0, o1, o2, ... ]
            # seeds = [ init_seed, init_seed+1, init_seed+2, ... ]
            rng = np.random.RandomState(self.init_seed + i)
            self.seed = (self.init_seed + i + 1) % self.num_masks
            self.m[-1] = invoke_order
        elif hasattr(self, 'orderings'):
            # Cycle through the special orderings.
            rng = np.random.RandomState(self.seed)
            self.seed = (self.seed + 1) % self.num_masks
            self.m[-1] = self.orderings[self.seed % len(self.orderings)]
        else:
            rng = np.random.RandomState(self.seed)
            self.seed = (self.seed + 1) % self.num_masks
            self.m[-1] = np.arange(
                self.nin) if self.natural_ordering else rng.permutation(
                    self.nin)
            if self.fixed_ordering is not None:
                self.m[-1] = np.asarray(self.fixed_ordering)

        if self.nin > 1:
            for l in range(L):
                if self.residual_connections:
                    # Sequential assignment for ResMade: https://arxiv.org/pdf/1904.05626.pdf
                    self.m[l] = np.array([(k - 1) % (self.nin - 1)
                                          for k in range(self.hidden_sizes[l])])
                else:
                    # Sample degrees from [m[l-1].min(), nin - 1).
                    self.m[l] = rng.randint(self.m[l - 1].min(),
                                            self.nin - 1,
                                            size=self.hidden_sizes[l])
        else:
            # This should result in the first layer's masks being all 0s,
            # so the output units are disconnected from every input.
            for l in range(L):
                self.m[l] = np.asarray([-1] * self.hidden_sizes[l])

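        # MADE connectivity: a unit with degree m[l][k] may receive from units
        # in the previous layer whose degree is <= m[l][k]; the output layer
        # uses a strict <, which enforces the autoregressive property.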
        masks = [self.m[l - 1][:, None] <= self.m[l][None, :] for l in range(L)]
        masks.append(self.m[L - 1][:, None] < self.m[-1][None, :])

        if self.nout > self.nin:
            # Last layer's mask needs to be changed.

            if self.input_bins is None:
                k = int(self.nout / self.nin)
                # Replicate the mask across the other outputs
                # so [x1, x2, ..., xn], ..., [x1, x2, ..., xn].
                masks[-1] = np.concatenate([masks[-1]] * k, axis=1)
            else:
                # [x1, ..., x1], ..., [xn, ..., xn] where the i-th list has
                # input_bins[i - 1] many elements (multiplicity, # of classes).
                mask = np.asarray([])
                for k in range(masks[-1].shape[0]):
                    tmp_mask = []
                    for idx, x in enumerate(zip(masks[-1][k], self.input_bins)):
                        mval, nbins = x[0], self._get_output_encoded_dist_size(
                            x[1])
                        tmp_mask.extend([mval] * nbins)
                    tmp_mask = np.asarray(tmp_mask)
                    if k == 0:
                        mask = tmp_mask
                    else:
                        mask = np.vstack([mask, tmp_mask])
                masks[-1] = mask

        if self.input_encoding is not None:
            # Input layer's mask should be changed.

            assert self.input_bins is not None
            # [nin, hidden].
            mask0 = masks[0]
            new_mask0 = []
            for i, dist_size in enumerate(self.input_bins):
                dist_size = self._get_input_encoded_dist_size(dist_size)
                # [dist size, hidden]
                new_mask0.append(
                    np.concatenate([mask0[i].reshape(1, -1)] * dist_size,
                                   axis=0))
            # [sum(dist size), hidden]
            new_mask0 = np.vstack(new_mask0)
            masks[0] = new_mask0

        layers = [
            l for l in self.net if isinstance(l, MaskedLinear) or
            isinstance(l, MaskedResidualBlock)
        ]
        assert len(layers) == len(masks), (len(layers), len(masks))
        for l, m in zip(layers, masks):
            l.set_mask(m)

        if self.do_direct_io_connections:
            self._build_or_update_direct_io()

    def name(self):
        n = 'made'
        if self.residual_connections:
            n += '-resmade'
        n += '-hidden' + '_'.join(str(h) for h in self.hidden_sizes)
        n += '-emb' + str(self.embed_size)
        if self.num_masks > 1:
            n += '-{}masks'.format(self.num_masks)
        if not self.natural_ordering:
            n += '-nonNatural'
        n += ('-no' if not self.do_direct_io_connections else '-') + 'directIo'
        n += '-{}In{}Out'.format(self.input_encoding, self.output_encoding)
        if self.input_no_emb_if_leq:
            n += '-inputNoEmbIfLeq'
        if self.column_masking:
            n += '-colmask'
        return n

    def Embed(self, data, natural_col=None, out=None):
        if data is None:
            if out is None:
                return self.unk_embeddings[natural_col]
            out.copy_(self.unk_embeddings[natural_col])
            return out

        bs = data.size()[0]
        y_embed = []
        data = data.long()

        if natural_col is not None:
            # Fast path only for inference.  One col.

            coli_dom_size = self.input_bins[natural_col]
            # Embed?
            if coli_dom_size > self.embed_size or not self.input_no_emb_if_leq:
                res = self.embeddings[natural_col](data.view(-1,))
                if out is not None:
                    out.copy_(res)
                    return out
                return res
            else:
                if out is None:
                    out = torch.zeros(bs, coli_dom_size, device=data.device)

                out.scatter_(1, data, 1)
                return out
        else:
            for i, coli_dom_size in enumerate(self.input_bins):
                # Wildcard column? use -1 as special token.
                # Inference pass only (see estimators.py).
                skip = data[0][i] < 0

                # Embed?
                if coli_dom_size > self.embed_size or not self.input_no_emb_if_leq:
                    col_i_embs = self.embeddings[i](data[:, i])
                    if not self.column_masking:
                        y_embed.append(col_i_embs)
                    else:
                        dropped_repr = self.unk_embeddings[i]

                        def dropout_p():
                            return np.random.randint(0, self.nin) / self.nin

                        # During training, non-dropped 1's are scaled by
                        # 1/(1-p), so we clamp back to 1.
                        batch_mask = torch.clamp(
                            torch.dropout(torch.ones(bs, 1, device=data.device),
                                          p=dropout_p(),
                                          train=self.training), 0, 1)
                        y_embed.append(batch_mask * col_i_embs +
                                       (1. - batch_mask) * dropped_repr)
                else:
                    if skip:
                        y_embed.append(self.unk_embeddings[i])
                        continue
                    y_onehot = torch.zeros(bs,
                                           coli_dom_size,
                                           device=data.device)
                    y_onehot.scatter_(1, data[:, i].view(-1, 1), 1)
                    if self.column_masking:

                        def dropout_p():
                            return np.random.randint(0, self.nin) / self.nin

                        # During training, non-dropped 1's are scaled by
                        # 1/(1-p), so we clamp back to 1.
                        batch_mask = torch.clamp(
                            torch.dropout(torch.ones(bs, 1, device=data.device),
                                          p=dropout_p(),
                                          train=self.training), 0, 1)
                        y_embed.append(batch_mask * y_onehot +
                                       (1. - batch_mask) *
                                       self.unk_embeddings[i])
                    else:
                        y_embed.append(y_onehot)
            return torch.cat(y_embed, 1)

    def ToOneHot(self, data):
        assert not self.column_masking, 'not implemented'
        bs = data.size()[0]
        y_onehots = []
        data = data.long()
        for i, coli_dom_size in enumerate(self.input_bins):
            # Always emit a full one-hot of width coli_dom_size so widths match
            # _get_input_encoded_dist_size() and the input-layer mask.
            y_onehot = torch.zeros(bs, coli_dom_size, device=data.device)
            y_onehot.scatter_(1, data[:, i].view(-1, 1), 1)
            y_onehots.append(y_onehot)

        # [bs, sum(dist size)]
        return torch.cat(y_onehots, 1)

    def ToBinaryAsOneHot(self, data, threshold=0, natural_col=None, out=None):
        if data is None:
            if out is None:
                return self.unk_embeddings[natural_col]
            out.copy_(self.unk_embeddings[natural_col])
            return out

        bs = data.size()[0]
        data = data.long()
        if self.bin_as_onehot_shifts is None:
            # This caching gives very sizable gains.
            self.bin_as_onehot_shifts = [None] * self.nin
            const_one = torch.ones([], dtype=torch.long, device=data.device)
            for i, coli_dom_size in enumerate(self.input_bins):
                # Max with 1 to guard against cols with 1 distinct val.
                one_hot_dims = max(1, int(np.ceil(np.log2(coli_dom_size))))
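                # E.g., domain size 5 needs ceil(log2(5)) = 3 bits, giving
                # shifts [1, 2, 4]; value 3 -> (3 & shifts) > 0 -> [1, 1, 0]
                # (least-significant bit first).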
                self.bin_as_onehot_shifts[i] = const_one << torch.arange(
                    one_hot_dims, device=data.device)

        if natural_col is None:
            # Train path.

            assert out is None
            y_onehots = [None] * self.nin
            for i, coli_dom_size in enumerate(self.input_bins):
                if coli_dom_size > threshold:
                    # Bit shift in PyTorch + GPU is 27% faster than np.
                    data_col = data.narrow(1, i, 1)
                    binaries = (data_col & self.bin_as_onehot_shifts[i]) > 0
                    y_onehots[i] = binaries

                    if self.column_masking:
                        dropped_repr = self.unk_embeddings[i]

                        def dropout_p():
                            return np.random.randint(0, self.nin) / self.nin

                        # During training, non-dropped 1's are scaled by
                        # 1/(1-p), so we clamp back to 1.
                        batch_mask = torch.clamp(
                            torch.dropout(torch.ones(bs, 1, device=data.device),
                                          p=dropout_p(),
                                          train=self.training), 0, 1)
                        binaries = binaries.to(torch.float32,
                                               non_blocking=True,
                                               copy=False)
                        y_onehots[i] = batch_mask * binaries + (
                            1. - batch_mask) * dropped_repr

                else:
                    # Encode as plain one-hot.
                    y_onehot = torch.zeros(bs,
                                           coli_dom_size,
                                           device=data.device)
                    y_onehot.scatter_(1, data[:, i].view(-1, 1), 1)
                    y_onehots[i] = y_onehot

            res = torch.cat(y_onehots, 1)
            return res.to(torch.float32, non_blocking=True, copy=False)

        else:
            # Inference path.
            natural_idx = natural_col
            coli_dom_size = self.input_bins[natural_idx]
            if coli_dom_size > threshold:
                # Bit shift in PyTorch + GPU is 27% faster than np.
                if out is None:
                    res = (data & self.bin_as_onehot_shifts[natural_idx]) > 0
                    return res.to(torch.float32, non_blocking=True, copy=False)
                else:
                    out.copy_(
                        (data & self.bin_as_onehot_shifts[natural_idx]) > 0)
                    return out
            else:
                assert False, 'inference'
                if out is None:
                    y_onehot = torch.zeros(bs,
                                           coli_dom_size,
                                           device=data.device)
                    y_onehot.scatter_(1, data, 1)
                    res = y_onehot
                    return res.to(torch.float32, non_blocking=True, copy=False)

                out.scatter_(1, data, 1)
                return out

    def EncodeInput(self, data, natural_col=None, out=None):
        """"Warning: this could take up a significant portion of a forward pass.

        Args:
          natural_col: if specified, 'data' has shape [N, 1] and corresponds to
              column 'natural_col'.  Otherwise 'data' corresponds to all columns.
          out: if specified, assign results into this Tensor storage.
        """
        if self.input_encoding == 'binary':
            return self.ToBinaryAsOneHot(data, natural_col=natural_col, out=out)
        elif self.input_encoding == 'embed':
            return self.Embed(data, natural_col=natural_col, out=out)
        elif self.input_encoding is None:
            return data
        elif self.input_encoding == 'one_hot':
            return self.ToOneHot(data)
        else:
            assert False, self.input_encoding

    def forward(self, x):
        """Calculates unnormalized logits.

        If self.input_bins is not specified, the output units are ordered as:
            [x1, x2, ..., xn], ..., [x1, x2, ..., xn].
        So they can be reshaped as thus and passed to a cross entropy loss:
            out.view(-1, model.nout // model.nin, model.nin)

        Otherwise, they are ordered as:
            [x1, ..., x1], ..., [xn, ..., xn]
        And they can't be reshaped directly.

        Args:
          x: [bs, ncols].
        """
        x = self.EncodeInput(x)

        if self.direct_io_layer is not None:
            residual = self.direct_io_layer(x)
            return self.net(x) + residual

        return self.net(x)

    def forward_with_encoded_input(self, x):
        """Like forward(), but 'x' is already encoded (see EncodeInput())."""
        if self.direct_io_layer is not None:
            residual = self.direct_io_layer(x)
            return self.net(x) + residual

        return self.net(x)

    def logits_for_col(self, idx, logits):
        """Returns the logits (vector) corresponding to log p(x_i | x_(<i)).

        Args:
          idx: int, in natural (table) ordering.
          logits: [batch size, hidden] where hidden can either be sum(dom
            sizes), or emb_dims.

        Returns:
          logits_for_col: [batch size, domain size for column idx].
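
        Example (sketch): with input_bins=[2, 5] and 'one_hot' outputs,
        logits_for_col(1, logits) returns logits[:, 2:7].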
        """
        assert self.input_bins is not None

        if idx == 0:
            logits_for_var = logits[:, :self.logit_indices[0]]
        else:
            logits_for_var = logits[:, self.logit_indices[idx - 1]:self.
                                    logit_indices[idx]]
        if self.output_encoding != 'embed':
            return logits_for_var

        embed = self.embeddings[idx]

        if embed is None:
            # Can be None for small-domain columns.
            return logits_for_var

        # Otherwise, dot with embedding matrix to get the true logits.
        # [bs, emb] * [emb, dom size for idx]
        return torch.matmul(logits_for_var, embed.weight.t())

    def nll(self, logits, data):
        """Calculates -log p(data), given logits (the conditionals).

        Args:
          logits: [batch size, hidden] where hidden can either be sum(dom
            sizes), or emb_dims.
          data: [batch size, nin].

        Returns:
          nll: [batch size].
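
        Example (sketch), given raw column values 'x' of shape [batch, nin]:
          logits = model(x.float())
          loss = model.nll(logits, x).mean()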
        """
        if data.dtype != torch.long:
            data = data.long()
        nll = torch.zeros(logits.size()[0], device=logits.device)
        for i in range(self.nin):
            logits_i = self.logits_for_col(i, logits)
            nll += F.cross_entropy(logits_i, data[:, i], reduction='none')

        return nll

    def sample(self, num=1, device=None):
        assert self.natural_ordering
        assert self.input_bins and self.nout > self.nin
        with torch.no_grad():
            sampled = torch.zeros((num, self.nin), device=device)
            indices = np.cumsum(self.input_bins)
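            # Ancestral sampling: repeatedly run the model on the partially
            # filled 'sampled' tensor and draw column i from p(x_i | x_<i).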
            for i in range(self.nin):
                logits = self.forward(sampled)
                s = torch.multinomial(
                    torch.softmax(self.logits_for_col(i, logits), -1), 1)
                sampled[:, i] = s.view(-1,)
        return sampled


if __name__ == '__main__':
    # Checks for the autoregressive property.
    rng = np.random.RandomState(14)
    # (nin, hiddens, nout, input_bins, direct_io)
    configs_with_input_bins = [
        (2, [10], 2 + 5, [2, 5], False),
        (2, [10, 30], 2 + 5, [2, 5], False),
        (3, [6], 2 + 2 + 2, [2, 2, 2], False),
        (3, [4, 4], 2 + 1 + 2, [2, 1, 2], False),
        (4, [16, 8, 16], 2 + 3 + 1 + 2, [2, 3, 1, 2], False),
        (2, [10], 2 + 5, [2, 5], True),
        (2, [10, 30], 2 + 5, [2, 5], True),
        (3, [6], 2 + 2 + 2, [2, 2, 2], True),
        (3, [4, 4], 2 + 1 + 2, [2, 1, 2], True),
        (4, [16, 8, 16], 2 + 3 + 1 + 2, [2, 3, 1, 2], True),
    ]
    for nin, hiddens, nout, input_bins, direct_io in configs_with_input_bins:
        print(nin, hiddens, nout, input_bins, direct_io, '...', end='')
        model = MADE(nin,
                     hiddens,
                     nout,
                     input_bins=input_bins,
                     natural_ordering=True,
                     do_direct_io_connections=direct_io)
        model.eval()
        print(model)
        for k in range(nout):
            inp = torch.tensor(rng.rand(1, nin).astype(np.float32),
                               requires_grad=True)
            loss = model(inp)
            l = loss[0, k]
            l.backward()
            depends = (inp.grad[0].numpy() != 0).astype(np.uint8)

            depends_ix = np.where(depends)[0].astype(np.int32)
            var_idx = np.argmax(k < np.cumsum(input_bins))
            prev_idxs = np.arange(var_idx).astype(np.int32)

            # Asserts that k depends only on < var_idx.
            print('depends', depends_ix, 'prev_idxs', prev_idxs)
            assert len(torch.nonzero(inp.grad[0, var_idx:])) == 0
        print('ok')
    print('[MADE] Passes autoregressive-ness check!')