#
# Author: Tiberiu Boros
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import tqdm
import numpy as np
from models.clarinet.wavenet import Wavenet
from models.clarinet.modules import GaussianLoss, stft, KL_Loss
from models.clarinet.wavenet_iaf import Wavenet_Student
from torch.distributions.normal import Normal

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _create_batches(y_target, mgc, batch_size, UPSAMPLE_COUNT=256, mgc_order=60):
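    """Split an aligned (waveform, MGC) pair into batched training tensors.

    The samples are grouped into chunks of 25 MGC frames (25 * UPSAMPLE_COUNT
    samples each), and batch_size chunks are stacked into one tensor. Returns
    three lists of tensors on `device`:
        x: (batch_size, 1, 25 * UPSAMPLE_COUNT)  network input
        y: (batch_size, 25 * UPSAMPLE_COUNT, 1)  target samples
        c: (batch_size, mgc_order, 25)           conditioning (MGC) frames
    A smaller, final batch is appended when the chunk count is not a multiple
    of batch_size.
    """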
    x_list = []
    y_list = []
    c_list = []

    x_mini_list = []
    y_mini_list = []
    c_mini_list = []
    mini_batch = 25

    for batch_index in range((len(mgc) - 1) // mini_batch):
        mgc_start = batch_index * mini_batch
        mgc_stop = batch_index * mini_batch + mini_batch
        c_mini_list.append(mgc[mgc_start:mgc_stop].reshape(mini_batch, mgc_order).transpose())
        x_start = batch_index * mini_batch * UPSAMPLE_COUNT
        x_stop = batch_index * mini_batch * UPSAMPLE_COUNT + mini_batch * UPSAMPLE_COUNT
        x_mini_list.append(y_target[x_start:x_stop].reshape(1, x_stop - x_start))
        y_mini_list.append(y_target[x_start:x_stop].reshape(x_stop - x_start, 1))

        if len(c_mini_list) == batch_size:
            x_list.append(torch.tensor(np.stack(x_mini_list)).to(device))
            y_list.append(torch.tensor(np.stack(y_mini_list)).to(device))
            c_list.append(torch.tensor(np.stack(c_mini_list)).to(device))
            c_mini_list = []
            x_mini_list = []
            y_mini_list = []

    if len(c_mini_list) != 0:
        x_list.append(torch.tensor(np.stack(x_mini_list)).to(device))
        y_list.append(torch.tensor(np.stack(y_mini_list)).to(device))
        c_list.append(torch.tensor(np.stack(c_mini_list)).to(device))

    return x_list, y_list, c_list


class WavenetVocoder:
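    """Autoregressive WaveNet vocoder with a single-Gaussian output, conditioned on MGC frames."""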
    def __init__(self, params):

        self.params = params

        self.UPSAMPLE_COUNT = 256
        self.RECEPTIVE_SIZE = 3 * 3 * 3 * 3 * 3 * 3
        self.model = Wavenet(out_channels=2,
                             num_blocks=4,
                             num_layers=6,
                             residual_channels=128,
                             gate_channels=256,
                             skip_channels=128,
                             kernel_size=3,
                             cin_channels=params.mgc_order,
                             upsample_scales=[16, 16]).to(device)

        self.loss = GaussianLoss()

        self.trainer = torch.optim.Adam(self.model.parameters(), lr=self.params.learning_rate)

    def learn(self, y_target, mgc, batch_size):
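        """Run one training pass over the utterance and return the average Gaussian NLL per batch."""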
        # prepare batches
        x_list, y_list, c_list = _create_batches(y_target, mgc, batch_size, UPSAMPLE_COUNT=self.UPSAMPLE_COUNT,
                                                 mgc_order=self.params.mgc_order)
        if len(x_list) == 0:
            return 0
        # learn
        total_loss = 0
        for x, y, c in tqdm.tqdm(zip(x_list, y_list, c_list), total=len(c_list)):
            x = x.to(device=device, dtype=torch.float32)
            y = y.to(device=device, dtype=torch.float32)
            c = c.to(device=device, dtype=torch.float32)
            self.trainer.zero_grad()
            y_hat = self.model(x, c)

            t_y = y[:, 1:]  # .reshape(1, y_hat.shape[0] * y_hat.shape[2] - 1, 1)
            p_y = y_hat[:, :, :-1]

            loss = self.loss(p_y, t_y, size_average=True)
            total_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.)
            self.trainer.step()
        return total_loss / len(x_list)

    def synthesize(self, mgc, batch_size, temperature=1.0):
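        """Autoregressively generate audio conditioned on the MGC frames and return int16 samples."""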
        num_samples = len(mgc) * self.UPSAMPLE_COUNT
        with torch.no_grad():
            c = torch.tensor(mgc.transpose(), dtype=torch.float32).to(device).reshape(1, mgc[0].shape[0], len(mgc))
            x = self.model.generate(num_samples - 1, c, device=device, temperature=temperature)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        x = x.squeeze().cpu().numpy() * 32768
        return x.astype('int16')

    def store(self, output_base):
        torch.save(self.model.state_dict(), output_base + ".network")

    def load(self, output_base):
        self.model.load_state_dict(torch.load(output_base + ".network", map_location=device))
        self.model.to(device)


class ClarinetVocoder:
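    """ClariNet-style student vocoder (IAF) distilled from a trained WavenetVocoder teacher."""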
    def __init__(self, params, vocoder=None):
        self.UPSAMPLE_COUNT = 256
        self.RECEPTIVE_SIZE = 3 * 3 * 3 * 3 * 3 * 3
        self.params = params
        self.model_t = vocoder.model
        self.model_s = Wavenet_Student(num_blocks_student=[1, 1, 1, 4],
                                       num_layers=6, cin_channels=self.params.mgc_order)
        self.model_s.to(device)

        # self.stft = STFT(filter_length=1024, hop_length=256).to(device)
        self.criterion_t = KL_Loss().to(device)
        self.criterion_frame = torch.nn.MSELoss().to(device)
        self.trainer = torch.optim.Adam(self.model_s.parameters(), lr=self.params.learning_rate)
        self.model_t.eval()
        self.model_s.train()

    def learn(self, y_target, mgc, batch_size):
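        """Distill the teacher into the student using the KL loss plus an STFT
        frame loss; returns the average total loss per batch."""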
        # prepare batches
        self.model_t.eval()
        self.model_s.train()
        x_list, y_list, c_list = _create_batches(y_target, mgc, batch_size, UPSAMPLE_COUNT=self.UPSAMPLE_COUNT,
                                                 mgc_order=self.params.mgc_order)
        if len(x_list) == 0:
            return 0
        # learn
        total_loss = 0
        for x, y, c in tqdm.tqdm(zip(x_list, y_list, c_list), total=len(c_list)):
            x = x.to(device=device, dtype=torch.float32)
            y = y.to(device=device, dtype=torch.float32)
            c = c.to(device=device, dtype=torch.float32)
            self.trainer.zero_grad()
            q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
            z = q_0.sample()
            # with torch.no_grad():
            c_up = self.model_t.upsample(c).detach()
            x_student, mu_s, logs_s = self.model_s(z, c_up)
            mu_logs_t = self.model_t(x_student, c)

            loss_t, loss_KL, loss_reg = self.criterion_t(mu_s, logs_s, mu_logs_t[:, 0:1, :-1], mu_logs_t[:, 1:, :-1],
                                                         size_average=True)
            stft_student = stft(x_student[:, 0, 1:], scale='linear')
            stft_truth = stft(x[:, 0, 1:], scale='linear')
            loss_frame = self.criterion_frame(stft_student, stft_truth.detach())
            loss_tot = loss_t + loss_frame
            total_loss += loss_tot.item()
            loss_tot.backward()

            torch.nn.utils.clip_grad_norm_(self.model_s.parameters(), 10.)
            self.trainer.step()
            del loss_tot, loss_frame, loss_KL, loss_reg, loss_t, x, y, c, c_up, stft_student, stft_truth, q_0, z
            del x_student, mu_s, logs_s, mu_logs_t

        return total_loss / len(x_list)

    def synthesize(self, mgc, batch_size, temperature=1.0):
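        """Generate audio in a single pass by transforming Gaussian noise (scaled
        by temperature) with the student network; returns int16 samples."""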
        num_samples = len(mgc) * self.UPSAMPLE_COUNT
        zeros = np.zeros((1, 1, num_samples))
        ones = np.ones((1, 1, num_samples))
        with torch.no_grad():
            c = torch.tensor(mgc.transpose(), dtype=torch.float32).to(device).reshape(1, mgc[0].shape[0], len(mgc))
            c_up = self.model_t.upsample(c)
            q_0 = Normal(torch.tensor(zeros, dtype=torch.float32).to(device),
                         torch.tensor(ones, dtype=torch.float32).to(device))
            z = q_0.sample() * temperature
            x = self.model_s.generate(z, c_up, device=device)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        x = x.squeeze().cpu().numpy() * 32768
        return x.astype('int16')

    def store(self, output_base):
        torch.save(self.model_s.state_dict(), output_base + ".network")

    def load(self, output_base):
        self.model_s.load_state_dict(torch.load(output_base + ".network", map_location=device))
        self.model_s.to(device)


class WaveGlowVocoder:
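    """Inference-only wrapper around a pre-trained NVIDIA WaveGlow checkpoint."""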
    def __init__(self, params):
        self.waveglow = None
        self.denoiser = None

    def learn(self, y_target, mgc, batch_size):
        # Training is not supported for this vocoder: it wraps a pre-trained WaveGlow
        # model, which is expected to be loaded with load() and used for inference only.
        raise NotImplementedError("WaveGlowVocoder is inference-only; load a pre-trained model with load()")

    def synthesize(self, mgc, batch_size, temperature=1.0):
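        """Map the mel/MGC magnitudes to a log (dB) scale, run WaveGlow inference
        and apply de-emphasis; returns int16 samples."""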
        mel = torch.tensor(mgc, dtype=torch.float32).cuda().transpose(0, 1)
        mel = torch.unsqueeze(mel, 0)
        mel = torch.log10(mel) * 20
        with torch.no_grad():
            audio = self.waveglow.infer(mel, sigma=temperature)
            audio = audio * 32768
        audio = audio.squeeze()
        audio = audio.cpu().numpy()
        from scipy import signal
        audio = signal.lfilter([1.0], [1.0, -0.97], audio)
        audio = audio.astype('int16')
        return audio

    def store(self, output_base):
        pass  # torch.save(self.model_s.state_dict(), output_base + ".network")

    def load(self, output_base):
        import sys
        sys.path.insert(0, 'cube/models/waveglow')
        self.waveglow = torch.load(output_base)['model']
        self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
        self.waveglow.cuda().eval()
        # if denoiser_strength > 0:
        #    self.denoiser = Denoiser(self.waveglow).cuda()
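
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes a `params` object exposing `mgc_order` and `learning_rate`, an MGC
# matrix `mgc` of shape (num_frames, mgc_order) and a float waveform `audio` in
# [-1, 1] with 256 samples per frame; adapt the names and data loading to the
# actual recipe.
#
#   from types import SimpleNamespace
#   params = SimpleNamespace(mgc_order=60, learning_rate=1e-4)
#   teacher = WavenetVocoder(params)
#   loss = teacher.learn(audio, mgc, batch_size=4)      # one pass over one utterance
#   wav = teacher.synthesize(mgc, batch_size=1)         # int16 waveform
#
#   student = ClarinetVocoder(params, vocoder=teacher)  # distill the teacher
#   loss = student.learn(audio, mgc, batch_size=4)
#   wav = student.synthesize(mgc, batch_size=1)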