"""Differentiable Viterbi / dynamic-programming layer over packed sequences.

The recursion is parameterised by a max operator: 'hardmax' recovers the
classical Viterbi algorithm, 'softmax' the forward (log-partition) recursion,
and 'sparsemax' a sparse relaxation. Each operator exposes the smoothed max,
its gradient (a distribution over states) and a Hessian-vector product, which
makes the layer twice differentiable.
"""
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence


class HardMaxOp:
    """Exact max: gradient is the (tie-averaged) argmax indicator, Hessian is zero."""

    @staticmethod
    def max(X):
        M, _ = torch.max(X, dim=2, keepdim=True)
        A = (M == X).float()
        A = A / torch.sum(A, dim=2, keepdim=True)
        return M.squeeze(), A.squeeze()

    @staticmethod
    def hessian_product(P, Z):
        return torch.zeros_like(Z)


class SoftMaxOp:
    """Log-sum-exp: gradient is the softmax distribution."""

    @staticmethod
    def max(X):
        M, _ = torch.max(X, dim=2)
        X = X - M[:, :, None]
        A = torch.exp(X)
        S = torch.sum(A, dim=2)
        M = M + torch.log(S)
        A /= S[:, :, None]
        return M.squeeze(), A.squeeze()

    @staticmethod
    def hessian_product(P, Z):
        prod = P * Z
        return prod - P * torch.sum(prod, dim=2, keepdim=True)


class SparseMaxOp:
    """Sparsemax: gradient is the Euclidean projection onto the simplex."""

    @staticmethod
    def max(X):
        seq_len, n_batch, n_states = X.shape
        X_sorted, _ = torch.sort(X, dim=2, descending=True)
        cssv = torch.cumsum(X_sorted, dim=2) - 1
        ind = torch.arange(1, n_states + 1, dtype=X.dtype, device=X.device)
        cond = X_sorted - cssv / ind > 0
        rho = cond.long().sum(dim=2)
        cssv = cssv.view(-1, n_states)
        rho = rho.view(-1)
        tau = (torch.gather(cssv, dim=1, index=rho[:, None] - 1)[:, 0]
               / rho.to(X.dtype))
        tau = tau.view(seq_len, n_batch)
        A = torch.clamp(X - tau[:, :, None], min=0)
        # A /= A.sum(dim=2, keepdim=True)

        M = torch.sum(A * (X - .5 * A), dim=2)
        return M.squeeze(), A.squeeze()

    @staticmethod
    def hessian_product(P, Z):
        S = (P > 0).to(Z.dtype)
        support = torch.sum(S, dim=2, keepdim=True)
        prod = S * Z
        return prod - S * torch.sum(prod, dim=2, keepdim=True) / support


operators = {'softmax': SoftMaxOp, 'sparsemax': SparseMaxOp,
             'hardmax': HardMaxOp}
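
# Illustrative sanity check (an assumed add-on, not part of the original
# module): each operator above maps a (T, B, S) tensor of scores to a (T, B)
# tensor of smoothed maxima and a (T, B, S) tensor of weights that form a
# distribution over the last dimension.
def _smoke_test_operators():
    X = torch.randn(2, 3, 4)
    for name, op in operators.items():
        M, A = op.max(X)
        assert M.shape == (2, 3), name
        assert A.shape == (2, 3, 4), name
        # Each row of weights sums to one.
        assert torch.allclose(A.sum(dim=-1), torch.ones(2, 3), atol=1e-5), name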


def _topological_loop(theta, batch_sizes, operator='softmax', adjoint=False,
                      Q=None, Qt=None):
    """Forward recursion over a packed sequence of transition potentials.

    theta has shape (L, S, S), with L = sum(batch_sizes) and S the number of
    states. Returns the terminal values Vt (one per sequence), the local
    weights Q over incoming states for every entry, and the terminal weights
    Qt. With adjoint=True, theta is interpreted as a perturbation Z and the
    loop computes the directional derivatives needed for double backward,
    reusing the Q and Qt of the original pass.
    """
    operator = operators[operator]
    B = batch_sizes[0].item()
    T = len(batch_sizes)
    L, S, _ = theta.size()
    if adjoint:
        Qd = theta.new_zeros(L + B, S, S)
        Qtd = theta.new_zeros(B, S)
        Vd = theta.new_zeros(L + B, S)
        Vdt = theta.new_zeros(B)
    else:
        Q = theta.new_zeros(L + B, S, S)
        Qt = theta.new_zeros(B, S)
        V = theta.new_zeros(L + B, S)
        Vt = theta.new_zeros(B)

    left = B
    term_right = B
    prev_length = B
    for t in range(T + 1):
        if t == T:
            cur_length = 0
        else:
            cur_length = batch_sizes[t].item()
        right = left + cur_length
        prev_left = left - prev_length
        prev_cut = right - prev_length
        len_term = prev_length - cur_length
        if cur_length != 0:
            # theta is offset by B relative to V/Q: the first B rows of V
            # hold the (zero) initial values and carry no potentials.
            if adjoint:
                M = (theta[left - B:right - B]
                     + Vd[prev_left:prev_cut][:, None, :])
                Vd[left:right] = torch.sum(Q[left:right] * M, dim=2)
                Qd[left:right] = operator.hessian_product(Q[left:right], M)
            else:
                M = (theta[left - B:right - B]
                     + V[prev_left:prev_cut][:, None, :])
                V[left:right], Q[left:right] = operator.max(M)
        term_left = term_right - len_term
        if len_term != 0:
            # len_term sequences end at this step: fold their last values
            # into the terminal buffers.
            if adjoint:
                M = Vd[prev_cut:left]
                Vdt[term_left:term_right] = torch.sum(
                    Qt[term_left:term_right] * M, dim=1)
                Qtd[term_left:term_right] = operator.hessian_product(
                    Qt[term_left:term_right][:, None, :],
                    M[:, None, :])[:, 0]
            else:
                M = V[prev_cut:left]
                Vt[term_left:term_right], Qt[term_left:term_right] \
                    = operator.max(M[:, None, :])
        term_right = term_left
        left = right
        prev_length = cur_length
    if adjoint:
        return Vdt, Qd, Qtd
    else:
        return Vt, Q, Qt


def _reverse_loop(Q, Qt, Ut, batch_sizes, adjoint=False, U=None,
                  Qd=None, Qdt=None):
    """Backward recursion: back-propagates Ut through the stored weights.

    Returns the per-edge quantities E (the gradient of the value with respect
    to theta, e.g. edge marginals for the softmax operator) together with the
    per-state quantities U. With adjoint=True it propagates the directional
    derivatives Qd, Qdt produced by the adjoint topological loop.
    """
    B = batch_sizes[0].item()
    T = len(batch_sizes)
    L, S, _ = Q.size()
    L = L - B
    if adjoint:
        Ed = Q.new_zeros(L, S, S)
        Ud = Q.new_zeros(L + B, S)
        Udt = Q.new_zeros(B)
    else:
        E = Q.new_zeros(L, S, S)
        U = Q.new_zeros(L + B, S)
        # Ut is used as provided.

    right = L + B
    term_left = 0
    prev_length = 0
    off_right = L
    for t in reversed(range(-1, T)):
        if t == -1:
            cur_length = B
        else:
            cur_length = batch_sizes[t].item()
        left = right - cur_length
        off_left = off_right - prev_length
        cut = left + prev_length
        len_term = cur_length - prev_length
        if prev_length != 0:
            prev_left, prev_cut = right, right + prev_length
            if adjoint:
                Ed[off_left:off_right] = (
                    Q[prev_left:prev_cut]
                    * Ud[prev_left:prev_cut][:, :, None]
                    + Qd[prev_left:prev_cut]
                    * U[prev_left:prev_cut][:, :, None])
                Ud[left:cut] = torch.sum(Ed[off_left:off_right], dim=1)
            else:
                E[off_left:off_right] = (Q[prev_left:prev_cut]
                                         * U[prev_left:prev_cut][:, :, None])
                U[left:cut] = torch.sum(E[off_left:off_right], dim=1)
        term_right = term_left + len_term
        if len_term > 0:
            # Sequences whose terminal step is reached here: seed the
            # recursion from the terminal weights.
            if adjoint:
                Ud[cut:right] = (Qt[term_left:term_right]
                                 * Udt[term_left:term_right][:, None]
                                 + Qdt[term_left:term_right]
                                 * Ut[term_left:term_right][:, None])
            else:
                U[cut:right] = (Qt[term_left:term_right]
                                * Ut[term_left:term_right][:, None])
        term_left = term_right
        right = left
        off_right = off_left
        prev_length = cur_length
    if not adjoint:
        return E, U, Ut
    else:
        return Ed, Ud, Udt


class ViterbiFunction(torch.autograd.Function):
    """Value of the (smoothed) Viterbi recursion, with a differentiable backward."""

    @staticmethod
    def forward(ctx, theta, batch_sizes, operator):
        Vt, Q, Qt = _topological_loop(theta, batch_sizes,
                                      operator=operator, adjoint=False)
        ctx.save_for_backward(theta, Q, Qt)
        ctx.others = batch_sizes, operator
        return Vt

    @staticmethod
    def backward(ctx, M):
        theta, Q, Qt = ctx.saved_tensors
        batch_sizes, operator = ctx.others
        return (ViterbiFunctionBackward.apply(theta, M, Q, Qt,
                                              batch_sizes, operator),
                None, None)


class ViterbiFunctionBackward(torch.autograd.Function):
    """Gradient of ViterbiFunction, itself differentiable (double backward)."""

    @staticmethod
    def forward(ctx, theta, M, Q, Qt, batch_sizes, operator):
        E, U, Ut = _reverse_loop(Q, Qt, M, batch_sizes, adjoint=False)
        ctx.save_for_backward(Q, Qt, U, Ut)
        ctx.others = batch_sizes, operator
        return E

    @staticmethod
    def backward(ctx, Z):
        Q, Qt, U, Ut = ctx.saved_tensors
        batch_sizes, operator = ctx.others
        Vdt, Qd, Qdt = _topological_loop(Z, batch_sizes, operator=operator,
                                         adjoint=True, Q=Q, Qt=Qt)
        Ed, _, _ = _reverse_loop(Q, Qt, Ut, batch_sizes, adjoint=True,
                                 Qd=Qd, Qdt=Qdt, U=U)
        return Ed, Vdt, None, None, None, None


class PackedViterbi(nn.Module):
    """Viterbi layer operating directly on a PackedSequence of potentials."""

    def __init__(self, operator):
        super().__init__()
        self.operator = operator

    def forward(self, theta):
        return ViterbiFunction.apply(theta.data, theta.batch_sizes,
                                     self.operator)

    def decode(self, theta):
        """Shortcut for inference: differentiate the value with respect to
        theta and return d(value)/d(theta) as a PackedSequence (edge marginals
        for 'softmax', the optimal path for 'hardmax')."""
        data, batch_sizes = theta.data, theta.batch_sizes
        with torch.enable_grad():
            data.requires_grad_()
            nll = self.forward(theta)
            v = torch.sum(nll)
            v_grad, = torch.autograd.grad(v, (data,), create_graph=True)
        return PackedSequence(v_grad, batch_sizes)


class Viterbi(nn.Module):
    """Convenience wrapper taking padded potentials of shape (T, B, S, S)."""

    def __init__(self, operator):
        super().__init__()
        self.packed_viterbi = PackedViterbi(operator=operator)

    def _pack(self, theta, lengths):
        T, B, S, _ = theta.shape
        if lengths is None:
            data = theta.contiguous().view(T * B, S, S)
            # batch_sizes of a PackedSequence are int64 and live on the CPU.
            batch_sizes = torch.full((T,), B, dtype=torch.long)
            return PackedSequence(data, batch_sizes)
        # lengths must be sorted in decreasing order, as expected by
        # pack_padded_sequence with its default settings.
        return pack_padded_sequence(theta, lengths)

    def forward(self, theta, lengths=None):
        return self.packed_viterbi(self._pack(theta, lengths))

    def decode(self, theta, lengths=None):
        return self.packed_viterbi.decode(self._pack(theta, lengths))
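

if __name__ == '__main__':
    # Minimal usage sketch (an assumed add-on, not part of the original
    # module). Shape conventions follow Viterbi._pack: theta is (T, B, S, S),
    # one S x S matrix of transition potentials per time step and batch item.
    _smoke_test_operators()

    torch.manual_seed(0)
    T, B, S = 5, 3, 4
    theta = torch.randn(T, B, S, S)

    viterbi = Viterbi(operator='softmax')

    # Forward pass: one value per sequence (the log-partition function for
    # 'softmax', the best-path score for 'hardmax').
    value = viterbi(theta)
    print(value.shape)  # torch.Size([3])

    # decode() differentiates the value with respect to theta and returns
    # the result as a PackedSequence (edge marginals under 'softmax').
    marginals = viterbi.decode(theta)
    print(marginals.data.shape)  # torch.Size([15, 4, 4]), i.e. (T * B, S, S)

    # Variable-length sequences; lengths must be sorted in decreasing order.
    lengths = torch.tensor([5, 4, 2])
    print(viterbi(theta, lengths=lengths).shape)  # torch.Size([3])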