import torch
import torch.nn.functional as F


def apply_filter(feat, filter, dilation_factors=None):
    """Apply the filter on the input features (feat) as a grouped 2D convolution.

    The number of groups is computed automatically from the channel counts.

    args:
        feat: Input features. Dimensions (images_in_sequence, sequences, feat_dim, H, W).
        filter: The filter to apply. Dimensions (sequences, feat_dim, fH, fW) or
            (sequences, filters, feat_dim/groups, fH, fW) for multiple filters per sequence.
        dilation_factors: Optional dict {dilation: num_filters_with_that_dilation}. Only
            valid in the multiple-filters case; filters are consumed in dict order.

    returns:
        scores: Output of filtering. Dimensions (images_in_sequence, sequences, yH, yW), or
            (images_in_sequence, sequences, filters, yH, yW) in the multiple-filters case.
    """
    multiple_filters = (filter.dim() == 5)

    # 'Same' padding for odd kernel sizes (yH == H, yW == W).
    padding = (filter.shape[-2] // 2, filter.shape[-1] // 2)

    num_images = feat.shape[0]
    num_sequences = feat.shape[1] if feat.dim() == 5 else 1
    num_filters = filter.shape[1] if multiple_filters else 1
    num_channels = feat.shape[-3]
    groups = num_channels // filter.shape[-3]

    assert num_filters % groups == 0 and num_channels % groups == 0

    if multiple_filters:
        if dilation_factors is None:
            # Fold sequences into the batch channel dimension and convolve all
            # sequences at once using grouped convolution.
            scores = F.conv2d(feat.reshape(num_images, -1, feat.shape[-2], feat.shape[-1]),
                              filter.view(-1, *filter.shape[-3:]),
                              padding=padding, groups=num_sequences * groups)
            return scores.view(num_images, num_sequences, -1, scores.shape[-2], scores.shape[-1])
        else:
            # Apply consecutive slices of the filter bank with different dilation
            # factors, then concatenate the per-dilation score maps.
            scores_all = []
            start_id = 0
            for d_factor, num_filters_with_d in dilation_factors.items():
                f_d = filter[:, start_id:start_id + num_filters_with_d, ...].contiguous()
                # NOTE(review): p + d - 1 keeps 'same' padding only for 3x3 kernels
                # (p == 1); verify if other kernel sizes are ever used with dilation.
                padding_d = [p + d_factor - 1 for p in padding]
                scores_d = F.conv2d(feat.reshape(num_images, -1, feat.shape[-2], feat.shape[-1]),
                                    f_d.view(-1, *f_d.shape[-3:]),
                                    padding=padding_d, groups=num_sequences * groups,
                                    dilation=d_factor)
                scores_d = scores_d.view(num_images, num_sequences, -1,
                                         scores_d.shape[-2], scores_d.shape[-1])
                scores_all.append(scores_d)
                start_id += num_filters_with_d
            scores = torch.cat(scores_all, dim=2)
            return scores

    # Single filter per sequence: one output channel per sequence.
    scores = F.conv2d(feat.reshape(num_images, -1, feat.shape[-2], feat.shape[-1]), filter,
                      padding=padding, groups=num_sequences)

    return scores.view(num_images, num_sequences, scores.shape[-2], scores.shape[-1])


def apply_feat_transpose(feat, input, filter_ksz, training=True, groups=1):
    """Apply the transposed operation of apply_filter w.r.t. the filter itself.

    Can be used to compute the filter gradient of an L2 filtering loss.

    args:
        feat: Input features. Dimensions (images_in_sequence, sequences, feat_dim, H, W).
        input: Input activation (e.g. residuals). Dimensions (images_in_sequence,
            sequences, yH, yW) or (images_in_sequence, sequences, filters, yH, yW).
        filter_ksz: Kernel size of the filter, int or (fH, fW).
        training: Choose the implementation that is faster for the given mode
            (v3 has a fast backward pass, v2 a fast forward pass).
        groups: Must be 1; other values are not implemented.

    returns:
        Output of the transposed operation. Dimensions (sequences, feat_dim, fH, fW).

    raises:
        NotImplementedError: If groups != 1.
    """
    if groups != 1:
        raise NotImplementedError('Not implemented other values of group.')

    if training or input.dim() == 5:
        # v3 is preferred whenever gradients may flow, and is the only variant
        # that handles the multiple-filters (5-dim input) case.
        return _apply_feat_transpose_v3(feat, input, filter_ksz)
    return _apply_feat_transpose_v2(feat, input, filter_ksz)


def _apply_feat_transpose_v1(feat, input, filter_ksz):
    """Reference implementation via conv_transpose2d. Very slow; kept for reference."""
    num_images = feat.shape[0]
    num_sequences = feat.shape[1] if feat.dim() == 5 else 1
    feat_sz = (feat.shape[-2], feat.shape[-1])
    if isinstance(filter_ksz, int):
        filter_ksz = (filter_ksz, filter_ksz)

    # trans_pad = sz + padding - filter_ksz
    trans_pad = [sz + ksz // 2 - ksz for sz, ksz in zip(feat_sz, filter_ksz)]

    filter_grad = F.conv_transpose2d(input.flip((2, 3)).view(1, -1, input.shape[-2], input.shape[-1]),
                                     feat.reshape(-1, feat.shape[-3], feat.shape[-2], feat.shape[-1]),
                                     padding=trans_pad, groups=num_images * num_sequences)

    return filter_grad.view(num_images, num_sequences, -1,
                            filter_grad.shape[-2], filter_grad.shape[-1]).sum(dim=0)


def _apply_feat_transpose_v2(feat, input, filter_ksz):
    """Fast forward and slow backward implementation of the transposed operation."""
    multiple_filters = (input.dim() == 5)

    num_images = feat.shape[0]
    num_sequences = feat.shape[1] if feat.dim() == 5 else 1
    num_filters = input.shape[2] if multiple_filters else 1
    if isinstance(filter_ksz, int):
        filter_ksz = (filter_ksz, filter_ksz)

    trans_pad = [(ksz - 1) // 2 for ksz in filter_ksz]

    if multiple_filters:
        # Use the activations as convolution weights; the resulting maps are the
        # (spatially flipped) filter gradients, hence the flip before returning.
        filter_grad = F.conv2d(input.reshape(-1, num_filters, input.shape[-2],
                                             input.shape[-1]).permute(1, 0, 2, 3),
                               feat.reshape(-1, 1, feat.shape[-2], feat.shape[-1]),
                               padding=trans_pad, groups=num_images * num_sequences)

        if num_images == 1:
            return filter_grad.view(num_filters, num_sequences, -1, filter_grad.shape[-2],
                                    filter_grad.shape[-1]).flip((3, 4)).permute(1, 0, 2, 3, 4)
        # Sum the per-image contributions.
        return filter_grad.view(num_filters, num_images, num_sequences, -1,
                                filter_grad.shape[-2], filter_grad.shape[-1]).sum(dim=1).flip((3, 4)).permute(1, 0, 2, 3, 4)

    filter_grad = F.conv2d(input.reshape(1, -1, input.shape[-2], input.shape[-1]),
                           feat.reshape(-1, 1, feat.shape[-2], feat.shape[-1]),
                           padding=trans_pad, groups=num_images * num_sequences)

    return filter_grad.view(num_images, num_sequences, -1,
                            filter_grad.shape[-2], filter_grad.shape[-1]).sum(dim=0).flip((2, 3))


def _apply_feat_transpose_v3(feat, input, filter_ksz):
    """Slow forward, fast backward implementation of the transposed operation."""
    multiple_filters = (input.dim() == 5)

    num_images = feat.shape[0]
    num_sequences = feat.shape[1] if feat.dim() == 5 else 1
    num_filters = input.shape[2] if multiple_filters else 1
    if isinstance(filter_ksz, int):
        filter_ksz = (filter_ksz, filter_ksz)

    trans_pad = [ksz // 2 for ksz in filter_ksz]

    # Correlate the features with the activations: channels become the batch
    # dimension, so the output spatial size equals the filter kernel size.
    filter_grad = F.conv2d(feat.reshape(-1, feat.shape[-3], feat.shape[-2],
                                        feat.shape[-1]).permute(1, 0, 2, 3),
                           input.reshape(-1, 1, input.shape[-2], input.shape[-1]),
                           padding=trans_pad, groups=num_images * num_sequences)

    if multiple_filters:
        if num_images == 1:
            return filter_grad.view(-1, num_sequences, num_filters, filter_grad.shape[-2],
                                    filter_grad.shape[-1]).permute(1, 2, 0, 3, 4)
        # Sum the per-image contributions.
        return filter_grad.view(-1, num_images, num_sequences, num_filters,
                                filter_grad.shape[-2], filter_grad.shape[-1]).sum(dim=1).permute(1, 2, 0, 3, 4)

    if num_images == 1:
        return filter_grad.permute(1, 0, 2, 3)
    return filter_grad.view(-1, num_images, num_sequences, filter_grad.shape[-2],
                            filter_grad.shape[-1]).sum(dim=1).permute(1, 0, 2, 3)


def _apply_feat_transpose_v4(feat, input, filter_ksz):
    """Slow forward, fast backward variant that batches over channels directly.

    NOTE(review): unlike v2/v3 this does not handle the multiple-filters case
    and requires 5-dim feat (it permutes dim 2); it is not called by
    apply_feat_transpose.
    """
    num_images = feat.shape[0]
    num_sequences = feat.shape[1] if feat.dim() == 5 else 1
    if isinstance(filter_ksz, int):
        filter_ksz = (filter_ksz, filter_ksz)

    trans_pad = [ksz // 2 for ksz in filter_ksz]

    filter_grad = F.conv2d(feat.permute(2, 1, 0, 3, 4).reshape(feat.shape[-3], -1,
                                                               feat.shape[-2], feat.shape[-1]),
                           input.permute(1, 0, 2, 3),
                           padding=trans_pad, groups=num_sequences)

    return filter_grad.permute(1, 0, 2, 3)


def filter_gradient(feat, filter, label=None, training=True):
    """Compute the gradient of the L2 filtering loss w.r.t. the filter.

    The loss is 0.5 * || apply_filter(feat, filter) - label ||^2, so the gradient
    is the transposed filtering operation applied to the residuals.

    args:
        feat: Input features. Dimensions (images_in_sequence, sequences, feat_dim, H, W).
        filter: The filter to apply. Dimensions (sequences, feat_dim, fH, fW).
        label: Ground-truth label in the L2 loss. Dimensions (images_in_sequence,
            sequences, yH, yW). If None, the raw scores are used as residuals.
        training: Forwarded to apply_feat_transpose to select the implementation.

    returns:
        filter_gradient: Same dimensions as the input filter (sequences, feat_dim, fH, fW).
    """
    residuals = apply_filter(feat, filter)
    if label is not None:
        residuals = residuals - label
    filter_ksz = (filter.shape[-2], filter.shape[-1])
    return apply_feat_transpose(feat, residuals, filter_ksz, training=training)