import os
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.optim as optim

from functional.utils import TrainClock


class BaseAgent(ABC):
    """Base agent: owns the network, losses, optimizer, and checkpoint I/O.

    Subclasses must implement forward(), which returns (outputs, losses).
    """

    def __init__(self, config, net):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.clock = TrainClock()
        self.device = config.device

        # flags gating the optional loss terms computed in subclasses' forward()
        self.use_triplet = config.use_triplet
        self.use_footvel_loss = config.use_footvel_loss

        # loss functions and their weights
        self.mse = nn.MSELoss()
        self.tripletloss = nn.TripletMarginLoss(margin=config.triplet_margin)
        self.triplet_weight = config.triplet_weight
        self.foot_idx = config.foot_idx
        self.footvel_loss_weight = config.footvel_loss_weight

        # optimizer with per-epoch exponential learning-rate decay
        self.optimizer = optim.Adam(self.net.parameters(), lr=config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.99)

    def save_network(self, name=None):
        if name is None:
            save_path = os.path.join(self.model_dir, "model_epoch{}.pth".format(self.clock.epoch))
        else:
            save_path = os.path.join(self.model_dir, name)
        # save CPU weights so the checkpoint is device-agnostic, then move the net back
        torch.save(self.net.cpu().state_dict(), save_path)
        self.net.to(self.device)

    def load_network(self, epoch):
        load_path = os.path.join(self.model_dir, "model_epoch{}.pth".format(epoch))
        # map to the agent's device so checkpoints load regardless of where they were saved
        state_dict = torch.load(load_path, map_location=self.device)
        self.net.load_state_dict(state_dict)

    @abstractmethod
    def forward(self, data):
        """Run the network on a batch; return (outputs, dict of named loss tensors)."""

    def update_network(self, loss_dict):
        # sum all named loss terms into one scalar before backprop
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        # passing an epoch to step() is deprecated; the scheduler keeps its own count,
        # so this should be called exactly once per epoch
        self.scheduler.step()

    def train_func(self, data):
        self.net.train()

        outputs, losses = self.forward(data)

        self.update_network(losses)

        return outputs, losses

    def val_func(self, data):
        self.net.eval()

        with torch.no_grad():
            outputs, losses = self.forward(data)

        return outputs, losses
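

# --- Illustrative sketch (assumption, not part of the original interface) ---
# A minimal hypothetical subclass showing the contract forward() must satisfy:
# return (outputs, dict of named loss tensors) so that update_network() can sum
# them and backprop. The subclass name and the batch keys "motion" / "target"
# are assumptions for illustration only.
class _ExampleAgent(BaseAgent):
    def forward(self, data):
        inputs = data["motion"].to(self.device)   # hypothetical batch key
        targets = data["target"].to(self.device)  # hypothetical batch key
        outputs = self.net(inputs)
        losses = {"recon": self.mse(outputs, targets)}
        return outputs, losses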