import torch
from util.util import tensor2im
from models.base_model import BaseModel

################## SeasonTransfer #############################
class SeasonTransferModel(BaseModel):
    """Summer/winter image translation model built on ``BaseModel``.

    This class only defines batch preparation and test-time translation.
    ``self.device``, ``self.opt``, ``self.enc_style``, ``self.enc_content``,
    ``self.dec`` and ``self.sample_latent_code`` are not defined here; they are
    presumably provided by ``BaseModel`` (NOTE(review): confirm there).
    """

    def __init__(self, opt):
        """Forward ``opt`` unchanged to ``BaseModel.__init__``."""
        BaseModel.__init__(self, opt)

    def prepare_data(self, data):
        """Assemble one batch on ``self.device`` and cache it.

        Args:
            data: A pair ``(img, attr_source)``. Each element is a sequence of
                tensors that is concatenated along dimension 0.

        Returns:
            The list ``[img, attr_source, index_target, weight_source]``, which
            is also stored in ``self.current_data``. ``index_target`` holds the
            integers ``range(-batch_size // 2, batch_size // 2)``; used as an
            index it rotates the batch by half its length (for a batch of two
            it is ``[-1, 0]``, i.e. the two samples are swapped).
            ``weight_source`` is a ``(batch_size, 1)`` tensor of ones.
        """
        img, attr_source = data
        img = torch.cat(img, 0).to(self.device)
        attr_source = torch.cat(attr_source, 0).to(self.device)
        batch_size = img.size(0)
        # Build the helper tensors directly on the target device instead of
        # creating them on the host and copying. The bounds are the same
        # expression as before (``-batch_size // 2`` floors towards -inf), and
        # integer bounds make ``torch.arange`` yield an int64 tensor.
        index_target = torch.arange(
            -batch_size // 2, batch_size // 2, device=self.device
        )
        weight_source = torch.ones([batch_size, 1], device=self.device)
        self.current_data = [img, attr_source, index_target, weight_source]
        return self.current_data

    def translation(self, data):
        """Translate a two-image batch in both directions without gradients.

        Sample 0 is labelled as the summer input and sample 1 as the winter
        input. Each sample is decoded with the attribute of the other sample,
        first combined with the other sample's encoded style and then with
        ``self.opt.n_samples`` styles drawn from ``self.sample_latent_code``.

        Args:
            data: Same structure as accepted by ``prepare_data``; must yield a
                batch of exactly two images.

        Returns:
            A list of ``(label, image)`` tuples, where ``image`` is whatever
            ``tensor2im`` returns: first the summer input followed by its
            ``s2w_*`` outputs, then the winter input followed by its ``w2s_*``
            outputs.

        Raises:
            ValueError: If the prepared batch does not contain exactly two
                images.
        """
        with torch.no_grad():
            self.prepare_data(data)
            img, attr_source, index_target, _ = self.current_data
            batch_size = img.size(0)
            # Explicit check instead of ``assert`` so the validation is not
            # stripped when Python runs with -O.
            if batch_size != 2:
                raise ValueError(
                    'translation expects a batch of exactly 2 images '
                    '(summer, winter), got {}'.format(batch_size)
                )
            # enc_style returns three values; only the first is used here.
            style_enc, _, _ = self.enc_style(img)
            # With index_target == [-1, 0] each sample receives the style and
            # attribute of the other sample.
            style_target_enc = style_enc[index_target]
            attr_target = attr_source[index_target]
            content = self.enc_content(img)
            results_s2w = [('input_summer', tensor2im(img[0].data))]
            results_w2s = [('input_winter', tensor2im(img[1].data))]
            # Translation driven by the style encoded from the other image.
            fakes = self.dec(
                content, torch.cat([attr_target, style_target_enc], dim=1)
            )
            results_s2w.append(('s2w_enc', tensor2im(fakes[0].data)))
            results_w2s.append(('w2s_enc', tensor2im(fakes[1].data)))
            # Translations driven by sampled styles with the same size as the
            # encoded style.
            for i in range(self.opt.n_samples):
                style_rand = self.sample_latent_code(style_enc.size())
                fakes = self.dec(
                    content, torch.cat([attr_target, style_rand], dim=1)
                )
                results_s2w.append(
                    ('s2w_rand_{}'.format(i + 1), tensor2im(fakes[0].data))
                )
                results_w2s.append(
                    ('w2s_rand_{}'.format(i + 1), tensor2im(fakes[1].data))
                )
            return results_s2w + results_w2s