import torch
import torch.nn as nn
import torch.nn.functional as F

from .rnn import DelayedRNN
from text import symbols
from utils.utils import en_symbols

class Attention(nn.Module):
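    """Monotonic Gaussian-mixture attention over the encoded text memory.

    An LSTM cell consumes the centralized feature concatenated with the
    previous context vector; W_g maps its hidden state to the 3M parameters
    (mean increments, scales, weights) of an M-component mixture whose mass
    over each memory slot gives the attention weights.
    """
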
    def __init__(self, hp):
        super(Attention, self).__init__()
        self.M = hp.model.gmm
        self.rnn_cell = nn.LSTMCell(
            input_size=2*hp.model.hidden,
            hidden_size=hp.model.hidden
        )
        self.W_g = nn.Linear(hp.model.hidden, 3*self.M)
        
    def attention(self, h_i, memory, ksi):
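        """Compute one attention step.

        Shapes (as used by forward below): h_i is (B, D), memory is
        (B, T_text, D), and ksi holds the (B, M) mixture means carried
        over from the previous step.
        """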
        phi_hat = self.W_g(h_i)

        # Split phi_hat into the 3M mixture parameters: exp() makes the mean
        # increment positive (enforcing monotonic attention) and the scales
        # beta strictly positive; softmax normalizes the mixture weights alpha.
        ksi = ksi + torch.exp(phi_hat[:, :self.M])
        beta = torch.exp(phi_hat[:, self.M:2*self.M])
        alpha = F.softmax(phi_hat[:, 2*self.M:3*self.M], dim=-1)
        
        # Memory slot positions 0..T_text-1, created on memory's device.
        u = torch.arange(memory.size(1), dtype=torch.float, device=memory.device)
        u_R = u + 1.5
        u_L = u + 0.5
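
        # Each sigmoid below is a logistic CDF with location ksi and scale
        # beta, so term1 - term2 is the mixture mass on the unit interval
        # around slot u, i.e. that slot's attention weight.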
        
        term1 = torch.sum(
            alpha.unsqueeze(-1) * torch.sigmoid(
                (u_R - ksi.unsqueeze(-1)) / beta.unsqueeze(-1)
            ),
            keepdim=True,
            dim=1
        )
        
        term2 = torch.sum(
            alpha.unsqueeze(-1) * torch.sigmoid(
                (u_L - ksi.unsqueeze(-1)) / beta.unsqueeze(-1)
            ),
            keepdim=True,
            dim=1
        )
        
        weights = term1 - term2

        context = torch.bmm(weights, memory)

        # Mass the mixture places beyond each slot; evaluated at the final
        # text position this signals that attention has passed the text end.
        termination = 1 - term1.squeeze(1)

        return context, weights, termination, ksi  # (B, 1, D), (B, 1, T), (B, T), (B, M)

    def forward(self, input_h_c, memory):
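        """Unroll the attention RNN over the T decoder frames.

        input_h_c: (B, T, D) centralized features; memory: (B, T_text, D).
        Returns contexts (B, T, D), with a residual connection to input_h_c,
        and the alignment matrix (B, T, T_text).
        """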
        B, T, D = input_h_c.size()
        
        context = input_h_c.new_zeros(B, D)
        h_i, c_i  = input_h_c.new_zeros(B, D), input_h_c.new_zeros(B, D)
        ksi = input_h_c.new_zeros(B, self.M)
        
        contexts, weights = [], []
        for i in range(T):
            # context is (B, D) on the first step and (B, 1, D) afterwards;
            # squeeze(1) normalizes it to (B, D) before concatenation.
            x = torch.cat([input_h_c[:, i], context.squeeze(1)], dim=-1)
            h_i, c_i = self.rnn_cell(x, (h_i, c_i))
            context, weight, termination, ksi = self.attention(h_i, memory, ksi)
            
            contexts.append(context)
            weights.append(weight)
            
        contexts = torch.cat(contexts, dim=1) + input_h_c  # residual connection
        alignment = torch.cat(weights, dim=1)
        # termination = torch.gather(termination, 1, (input_lengths-1).unsqueeze(-1))

        return contexts, alignment  # , termination



class TTS(nn.Module):
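    """Autoregressive TTS network: stacked DelayedRNN layers over a mel
    input, with GMM attention over the text injected at the middle layer
    and a K-component Gaussian mixture output head.
    """
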
    def __init__(self, hp, freq, layers):
        super(TTS, self).__init__()
        self.hp = hp

        # Input projections for the time-delayed (W_t_0), frequency-delayed
        # (W_f_0) and centralized (W_c_0) stacks; the one-step shifts are
        # applied in forward().
        self.W_t_0 = nn.Linear(1, hp.model.hidden)
        self.W_f_0 = nn.Linear(1, hp.model.hidden)
        self.W_c_0 = nn.Linear(freq, hp.model.hidden)
        
        self.layers = nn.ModuleList([DelayedRNN(hp) for _ in range(layers)])

        # Gaussian Mixture Model: eq. (2)
        self.K = hp.model.gmm

        # map output to produce GMM parameter eq. (10)
        self.W_theta = nn.Linear(hp.model.hidden, 3*self.K)

        if self.hp.data.name == 'KSS':
            self.embedding_text = nn.Embedding(len(symbols), hp.model.hidden)
        elif self.hp.data.name == 'Blizzard':
            self.embedding_text = nn.Embedding(len(en_symbols), hp.model.hidden)
        else:
            raise NotImplementedError

        self.text_lstm = nn.LSTM(
            input_size=hp.model.hidden,
            hidden_size=hp.model.hidden//2, 
            batch_first=True,
            bidirectional=True
        )
        
        self.attention = Attention(hp)

    def text_encode(self, text, text_lengths):
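        """Embed the character sequence and encode it with the bidirectional
        LSTM, packing by text_lengths so padding does not leak into the
        returned memory of shape (B, T_text, D)."""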
        total_length = text.size(1)
        embed = self.embedding_text(text)
        packed = nn.utils.rnn.pack_padded_sequence(
            embed,
            text_lengths.cpu(),  # recent PyTorch expects CPU lengths (assumes a tensor)
            batch_first=True,
            enforce_sorted=False
        )
        memory, _ = self.text_lstm(packed)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(
            memory,
            batch_first=True,
            total_length=total_length
        )
        return unpacked
        
    def forward(self, x, text, text_lengths, audio_lengths):
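        """x: (B, M, T) mel input. Returns the unconstrained GMM parameter
        slices mu, std, pi (each (B, M, T, K)) and the attention alignment."""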
        # Extract memory
        memory = self.text_encode(text, text_lengths)
        
        # x: [B, M, T] / B=batch, M=mel, T=time
        # One-step causal shifts via F.pad (negative padding crops): along
        # time for the time-delayed and centralized stacks, along frequency
        # for the frequency-delayed stack.
        h_t = self.W_t_0(F.pad(x, [1, -1]).unsqueeze(-1))
        h_f = self.W_f_0(F.pad(x, [0, 0, 1, -1]).unsqueeze(-1))
        h_c = self.W_c_0(F.pad(x, [1, -1]).transpose(1, 2))
        
        # h_t, h_f: [B, M, T, D] / h_c: [B, T, D]
        for i, layer in enumerate(self.layers):
            # Inject text context into the centralized stack at the middle
            # layer only; every layer then updates all three stacks.
            if i == (len(self.layers)//2):
                h_c, alignment = self.attention(h_c, memory)
            h_t, h_f, h_c = layer(h_t, h_f, h_c, audio_lengths)

        theta_hat = self.W_theta(h_f)  # (B, M, T, 3K)

        mu = theta_hat[..., :self.K]           # eq. (3)
        std = theta_hat[..., self.K:2*self.K]  # eq. (4), pre-exp
        pi = theta_hat[..., 2*self.K:]         # eq. (5), pre-softmax
            
        return mu, std, pi, alignment
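

if __name__ == "__main__":
    # Minimal smoke test for the Attention module alone. This is a sketch:
    # it assumes the package-relative imports above resolve (e.g. run as
    # `python -m <package>.tts` from the repo root) and uses placeholder
    # hyperparameter values mirroring the hp fields this file reads.
    from types import SimpleNamespace

    hp = SimpleNamespace(model=SimpleNamespace(hidden=16, gmm=4))
    attn = Attention(hp)

    B, T_dec, T_text = 2, 5, 7
    input_h_c = torch.randn(B, T_dec, hp.model.hidden)
    memory = torch.randn(B, T_text, hp.model.hidden)

    contexts, alignment = attn(input_h_c, memory)
    assert contexts.shape == (B, T_dec, hp.model.hidden)
    assert alignment.shape == (B, T_dec, T_text)
    print("Attention OK:", tuple(contexts.shape), tuple(alignment.shape))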