import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from . import SyllableBaseModel

class Model(SyllableBaseModel):

    def __init__(self, no_vocabs, embedding_dim=16, window_size=1):
        super(Model, self).__init__()

        self.input_size = 2*window_size+1
        
        self.embeddings = nn.ModuleList([
            nn.Embedding(
                no_vocabs,
                embedding_dim,
                padding_idx=0
            ) for i in range(self.input_size)
        ])

        self.pooling = nn.AvgPool1d(embedding_dim)

        self.linear1 = nn.Linear(embedding_dim, 8)
        self.linear2 = nn.Linear(8, 1)

    def forward(self, inputs):
        embeds = list(map(
            lambda p: p[1](inputs[:, p[0]]), zip(range(self.input_size), self.embeddings)
        ))

        x = torch.stack(embeds, 2)
        x = self.pooling(x).view(x.size()[0], -1)

        out = F.relu(self.linear1(x))
        out = self.linear2(out)

        return out