import numpy as np

from torch import nn
from torch.nn import functional as F

import utils


class UNet(nn.Module):
    """Fully connected network with U-Net-style additive skip connections.

    An initial linear layer maps ``in_features`` to ``max_hidden_features``.
    Next, ``num_layers`` linear "down" layers each halve the width, and a
    width-preserving middle layer follows. Then ``num_layers`` linear "up"
    layers each double the width back to ``max_hidden_features``. Before each
    up layer, the stored down-path activation of matching width is added to
    the running activation. A final linear layer maps to ``out_features``.
    No nonlinearity is applied to the output.

    Args:
        in_features: Size of the last dimension of the input.
        max_hidden_features: Width of the widest hidden layer. Must be a
            power of two (checked with ``utils.is_power_of_two``).
        num_layers: Number of down layers (and of up layers). The narrowest
            hidden width, ``max_hidden_features // 2 ** num_layers``, must be
            greater than 1.
        out_features: Size of the last dimension of the output.
        nonlinearity: Activation callable applied to hidden activations.
            Defaults to ``F.relu``.

    Raises:
        AssertionError: If either constraint on ``max_hidden_features`` /
            ``num_layers`` is violated.
    """

    def __init__(self,
                 in_features,
                 max_hidden_features,
                 num_layers,
                 out_features,
                 nonlinearity=F.relu):
        super().__init__()

        assert utils.is_power_of_two(max_hidden_features), \
            '\'max_hidden_features\' must be a power of two.'
        # Every halving must leave a hidden width of at least 2.
        assert max_hidden_features // 2 ** num_layers > 1, \
            '\'num_layers\' must be {} or fewer'.format(
                int(np.log2(max_hidden_features) - 1))

        self.nonlinearity = nonlinearity
        self.num_layers = num_layers

        self.initial_layer = nn.Linear(in_features, max_hidden_features)

        # Widths: max, max/2, ..., max/2**num_layers.
        self.down_layers = nn.ModuleList([
            nn.Linear(
                in_features=max_hidden_features // 2 ** i,
                out_features=max_hidden_features // 2 ** (i + 1)
            )
            for i in range(num_layers)
        ])

        self.middle_layer = nn.Linear(
            in_features=max_hidden_features // 2 ** num_layers,
            out_features=max_hidden_features // 2 ** num_layers)

        # Mirror of the down path, ordered narrowest -> widest.
        self.up_layers = nn.ModuleList([
            nn.Linear(
                in_features=max_hidden_features // 2 ** (i + 1),
                out_features=max_hidden_features // 2 ** i
            )
            for i in range(num_layers - 1, -1, -1)
        ])

        self.final_layer = nn.Linear(max_hidden_features, out_features)

    def forward(self, inputs):
        """Apply the network to ``inputs``.

        Args:
            inputs: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor whose last dimension has size ``out_features``.
        """
        temps = self.nonlinearity(self.initial_layer(inputs))

        # Keep every down-path activation for the skip connections.
        down_temps = []
        for layer in self.down_layers:
            temps = self.nonlinearity(layer(temps))
            down_temps.append(temps)

        temps = self.nonlinearity(self.middle_layer(temps))

        # Up layers go narrow -> wide, so they pair with the stored down
        # activations in reverse order (the widths match).
        for layer, skip in zip(self.up_layers, reversed(down_temps)):
            # Out-of-place add: an in-place ``+=`` here would modify the
            # output of ``self.nonlinearity``. Activations such as ReLU save
            # their output for the backward pass, so that in-place change
            # makes ``backward()`` fail.
            temps = temps + skip
            temps = self.nonlinearity(temps)
            temps = layer(temps)

        return self.final_layer(temps)