"""
=================
This model is a recurrent model for inferring the action of the dog.
Given a sequence of images, infer the IMU changes corresponding to those
frames.
=================
"""
import torch
import torch.nn as nn
import pdb
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision.models import alexnet as torchvision_alexnet
from extensions.multi_label_cross_entropy import MultiLabelCrossEntropyLoss
from .basemodel import BaseModel
from torch.autograd import Variable
from .lstm import Lstm
from training import metrics
import pdb

# Learning-rate schedule. Kept in one place so that ``optimizer`` and
# ``learning_rate`` cannot drift apart.
_BASE_LR = 0.001
_LR_DECAY_RATE = 0.1
_LR_DECAY_STEP = 90  # number of epochs spent at each rate
_LR_MAX_DECAYS = 3  # the rate stops shrinking after this many decays


def _last_steps(target, length):
    """Return the final ``length`` entries of ``target`` along dimension 1.

    ``target[:, -length:]`` returns *every* entry when ``length`` is 0
    (because ``-0 == 0``), so the slice starts from an explicit index
    instead. The start is clamped at 0 so that a ``length`` larger than the
    sequence still yields the whole sequence, as the negative slice did.
    """
    start = max(target.size(1) - length, 0)
    return target[:, start:]


class LstmImg2LastImus(BaseModel):
    """Embed per-frame image features and run them through an LSTM.

    Only the first ``args.input_length`` entries of the input sequence are
    fed to the network, and only the last ``args.output_length`` entries of
    the target sequence are kept as ground truth.
    """

    # Metric classes associated with this model.
    metric = [
        metrics.SequenceMultiClassMetric,
        metrics.AllAtOnce,
        metrics.AngleEvaluationMetric,
    ]

    def __init__(self, args):
        """Build the input embedding and the LSTM from ``args``.

        Reads ``args.input_length``, ``args.output_length``,
        ``args.dataset.CLASS_WEIGHTS``, ``args.imus``, ``args.image_feature``
        and ``args.hidden_size``; ``args`` itself is also handed to ``Lstm``.
        """
        super(LstmImg2LastImus, self).__init__()
        self.input_length = args.input_length
        self.output_length = args.output_length
        # Keep only the class-weight rows of the IMUs selected in args.imus.
        self.class_weights = args.dataset.CLASS_WEIGHTS[torch.LongTensor(
            args.imus)]
        self.embedding_input = nn.Linear(args.image_feature, args.hidden_size)
        self.lstm = Lstm(args)

    def forward(self, input, target):
        """Run the model on ``input`` and pair the result with its target.

        Args:
            input: tensor sliced to its first ``input_length`` entries along
                dimension 1; its last dimension must match
                ``args.image_feature`` (it is fed to the embedding layer).
                Presumably (batch, sequence, feature) -- TODO confirm.
            target: tensor of which only the last ``output_length`` entries
                along dimension 1 are kept.

        Returns:
            A tuple ``(output, target, output_indices)``: the LSTM output
            with dimensions 0 and 1 swapped back, the trimmed target, and a
            ``LongTensor`` holding the positions of the kept target entries
            within the original target sequence.
        """
        input = input[:, :self.input_length]
        total_steps = target.size(1)
        output_indices = list(
            range(total_steps - self.output_length, total_steps))
        target = _last_steps(target, self.output_length)
        # The embedding/LSTM consume the input with dims 0 and 1 swapped.
        embedded_input = self.embedding_input(input.transpose(0, 1))
        # No target is given to the LSTM here (contrast with ``perplexity``).
        full_output = self.lstm(embedded_input, target=None)
        return (full_output.transpose(0, 1), target,
                torch.LongTensor(output_indices))

    def loss(self):
        """Return the multi-label cross-entropy loss with the class weights."""
        return MultiLabelCrossEntropyLoss(self.class_weights)

    def optimizer(self):
        """Return an Adam optimizer over all parameters at the base rate."""
        return optim.Adam(self.parameters(), lr=_BASE_LR)

    def perplexity(self, input, target):
        """Score how much probability the model assigns to the ground truth.

        The LSTM is run with ``target`` supplied (teacher forcing), the
        probability of each ground-truth class is gathered along dimension
        3, and the result is ``100 * mean_over_dim0(exp(mean_over_dim1(log
        p)))``.

        NOTE(review): this is a geometric-mean probability expressed as a
        percentage rather than the textbook ``exp(-mean(log p))`` perplexity
        -- confirm that this is what consumers of the value expect.

        Args:
            input: same layout as in ``forward``.
            target: integer class indices; must be 3-D after trimming, since
                it is unsqueezed to index dimension 3 of the probabilities.

        Returns:
            Tensor with dimensions 0 and 1 averaged out and scaled by 100.
        """
        input = input[:, :self.input_length]
        target = _last_steps(target, self.output_length)
        embedded_input = self.embedding_input(input.transpose(0, 1))
        # Get probabilities with teacher forcing.
        probabilities = self.lstm(embedded_input, target).transpose(0, 1)
        # Pick, for every entry, the probability of its ground-truth class.
        gt_probabilities = probabilities.gather(
            3, target.unsqueeze(3)).squeeze(3)
        gt_avg_log_probability = gt_probabilities.log().mean(1)
        perplexity = gt_avg_log_probability.exp().mean(0)
        return 100 * perplexity

    def learning_rate(self, epoch):
        """Return the step-decayed learning rate for a 1-based ``epoch``.

        The rate starts at the base rate and is multiplied by the decay rate
        after every ``_LR_DECAY_STEP`` epochs, at most ``_LR_MAX_DECAYS``
        times: epochs 1-90 use 0.001, 91-180 one decay, 181-270 two decays,
        and everything later three decays.

        Raises:
            ValueError: if ``epoch`` is smaller than 1.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not (1 <= epoch):
            raise ValueError("epoch must be >= 1, got {!r}".format(epoch))
        lr = _BASE_LR
        for decay_index in range(1, _LR_MAX_DECAYS + 1):
            if epoch <= decay_index * _LR_DECAY_STEP:
                return lr
            # Multiply one factor at a time to keep the exact float results
            # of ``base_lr * decay_rate * decay_rate ...``.
            lr = lr * _LR_DECAY_RATE
        return lr