# -*- coding: utf-8 -*-
"""PiCANet saliency model: a dilated VGG-16 encoder plus attention decoders.

``Unet`` chains an ``Encoder`` with six ``DecoderCell``s.  Five of the cells
apply a global (``PicanetG``) or local (``PicanetL``) pixel-wise contextual
attention module; the last one only predicts the final saliency map.
"""
import copy
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import torchvision
except ImportError:  # torchvision is only needed by the __main__ demo below
    torchvision = None

# Default model configuration:
#   'PicaNet'    - attention mode of decoder cells 0..4 ('G' global, 'L' local)
#   'Size'       - feature-map side length at each decoder cell (only the
#                  'G' cells use it, to size their ReNet)
#   'Channel'    - decoder cell i takes Channel[i] input channels and emits
#                  Channel[i + 1] channels
#   'loss_ratio' - weight of each side output's BCE loss
DEFAULT_CFG = {'PicaNet': "GGLLL",
               'Size': [28, 28, 28, 56, 112, 224],
               'Channel': [1024, 512, 512, 256, 128, 64],
               'loss_ratio': [0.5, 0.5, 0.5, 0.8, 0.8, 1]}
# Kept under its original public name; a separate copy so that mutating it
# does not change the defaults used by ``Unet()``.
cfg = copy.deepcopy(DEFAULT_CFG)


class Unet(nn.Module):
    """Encoder/decoder saliency network with PiCANet attention decoders.

    Args:
        cfg: configuration dict with keys 'PicaNet', 'Size', 'Channel' and
            'loss_ratio' (see ``DEFAULT_CFG``).  ``None`` (the default) uses a
            fresh copy of ``DEFAULT_CFG``.

    Raises:
        ValueError: if one of the first five 'PicaNet' entries is not 'G' or 'L'.
    """

    def __init__(self, cfg=None):
        super(Unet, self).__init__()
        if cfg is None:
            cfg = copy.deepcopy(DEFAULT_CFG)
        self.encoder = Encoder()
        self.decoder = nn.ModuleList()
        self.cfg = cfg
        for i in range(5):
            mode = cfg['PicaNet'][i]
            if mode not in ('G', 'L'):
                raise ValueError("cfg['PicaNet'][{}] must be 'G' or 'L', got {!r}".format(i, mode))
            self.decoder.append(
                DecoderCell(size=cfg['Size'][i],
                            in_channel=cfg['Channel'][i],
                            out_channel=cfg['Channel'][i + 1],
                            mode=mode))
        # Final cell: no attention, predicts a single-channel map.
        self.decoder.append(DecoderCell(size=cfg['Size'][5],
                                        in_channel=cfg['Channel'][5],
                                        out_channel=1,
                                        mode='C'))

    def forward(self, *input):
        """Run the network and, unless in test mode, compute the training loss.

        Call as ``model(x)`` (test mode, no loss), ``model(x, target)``
        (training mode) or ``model(x, target, test_mode)``.

        Returns:
            ``(pred, loss, attention)``: ``pred`` is the list of the six
            sigmoid side outputs in decoder order; ``loss`` is the weighted sum
            of their BCE losses against ``target`` (``0`` in test mode);
            ``attention`` holds the attention maps of the five PiCANet cells.

        Raises:
            TypeError: if called with other than 1, 2 or 3 arguments.
        """
        if len(input) == 1:
            x, tar, test_mode = input[0], None, True
        elif len(input) == 2:
            x, tar = input
            test_mode = False
        elif len(input) == 3:
            x, tar, test_mode = input
        else:
            raise TypeError('Unet.forward expects 1 to 3 arguments, got {}'.format(len(input)))

        en_out = self.encoder(x)
        dec = None
        pred = []
        attention = []
        # Decoder cell i consumes encoder output 5 - i (deepest features first).
        for i, cell in enumerate(self.decoder):
            dec, _pred, _attention = cell(en_out[5 - i], dec)
            pred.append(_pred)
            if _attention is not None:
                attention.append(_attention)

        loss = 0
        if not test_mode:
            # Walk from the finest prediction to the coarsest, max-pooling the
            # target down to each prediction's resolution.
            for i in reversed(range(len(pred))):
                while tar.size(2) > pred[i].size(2):
                    tar = F.max_pool2d(tar, 2, 2)
                loss += F.binary_cross_entropy(pred[i], tar) * self.cfg['loss_ratio'][i]
        return pred, loss, attention


def make_layers(cfg, in_channels):
    """Build a VGG-style feature extractor from a layer specification.

    Args:
        cfg: sequence of layer specs.  An int ``v`` adds a 3x3 conv with ``v``
            output channels followed by an in-place ReLU; ``'M'`` adds a 2x2
            stride-2 max-pool; ``'m'`` adds a 1x1 stride-1 max-pool (a no-op
            that keeps layer indices aligned with VGG-16 so its weights can be
            loaded) and makes every later conv dilated by 2.
        in_channels: number of channels fed to the first conv.

    Returns:
        The layers wrapped in an ``nn.Sequential``.
    """
    layers = []
    dilation_flag = False
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        elif v == 'm':
            layers.append(nn.MaxPool2d(kernel_size=1, stride=1))
            dilation_flag = True
        else:
            if dilation_flag:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=2, dilation=2)
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


class Encoder(nn.Module):
    """Dilated VGG-16 backbone returning six feature maps.

    The last two VGG-16 pools are replaced by no-op pools and the convs after
    them are dilated, so conv5 and the fc6/fc7 convs keep conv4's resolution.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # Plain VGG-16 would be:
        # [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
        configure = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                     512, 512, 512, 'm', 512, 512, 512, 'm']
        self.seq = make_layers(configure, 3)
        self.conv6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=12, dilation=12)  # fc6 in paper
        self.conv7 = nn.Conv2d(1024, 1024, 3, 1, 1)  # fc7 in paper

    def forward(self, *input):
        """Return ``(conv1, conv2, conv3, conv4, conv5, conv7)`` for ``input[0]``."""
        x = input[0]
        # Slice boundaries follow the layer indices produced by make_layers.
        conv1 = self.seq[:4](x)
        conv2 = self.seq[4:9](conv1)
        conv3 = self.seq[9:16](conv2)
        conv4 = self.seq[16:23](conv3)
        conv5 = self.seq[23:](conv4)
        conv6 = self.conv6(conv5)
        conv7 = self.conv7(conv6)
        return conv1, conv2, conv3, conv4, conv5, conv7


class DecoderCell(nn.Module):
    """One decoder stage: fuse encoder and decoder features, then attend.

    Args:
        size: feature-map side length, only used by ``mode='G'``.
        in_channel: channels of the encoder feature map (and of the incoming
            decoder feature map).
        out_channel: channels of the returned decoder feature map (ignored for
            ``mode='C'``, whose output has a single channel).
        mode: ``'G'`` global PiCANet, ``'L'`` local PiCANet, ``'C'`` no
            attention (final prediction cell).

    Raises:
        ValueError: on an unknown ``mode``.
    """

    def __init__(self, size, in_channel, out_channel, mode):
        super(DecoderCell, self).__init__()
        self.bn_en = nn.BatchNorm2d(in_channel)
        self.conv1 = nn.Conv2d(2 * in_channel, in_channel, kernel_size=1, padding=0)
        self.mode = mode
        if mode == 'G':
            self.picanet = PicanetG(size, in_channel)
        elif mode == 'L':
            self.picanet = PicanetL(in_channel)
        elif mode == 'C':
            self.picanet = None
        else:
            raise ValueError("mode must be 'G', 'L' or 'C', got {!r}".format(mode))
        if mode == 'C':
            self.conv2 = nn.Conv2d(in_channel, 1, kernel_size=1, padding=0)
        else:
            self.conv2 = nn.Conv2d(2 * in_channel, out_channel, kernel_size=1, padding=0)
            self.bn_feature = nn.BatchNorm2d(out_channel)
            self.conv3 = nn.Conv2d(out_channel, 1, kernel_size=1, padding=0)

    def forward(self, *input):
        """Decode ``(en[, dec])``.

        ``en`` is the encoder feature map; ``dec`` is the previous decoder
        output (``None`` or omitted for the first cell).  ``dec`` must have the
        same spatial size as ``en`` or half of it (it is then upsampled 2x).

        Returns:
            ``(dec_out, _y, attention)``: the decoder feature map, the sigmoid
            side-output prediction and the attention maps (``None`` for 'C').

        Raises:
            TypeError: if called with other than 1 or 2 arguments.
            ValueError: if ``dec`` has an incompatible spatial size.
        """
        if len(input) not in (1, 2):
            raise TypeError('DecoderCell.forward expects 1 or 2 arguments, got {}'.format(len(input)))
        en = input[0]
        dec = input[1] if len(input) == 2 else None
        if dec is None:
            dec = en  # not specified in paper
        elif dec.size(2) * 2 == en.size(2):
            dec = F.interpolate(dec, scale_factor=2, mode='bilinear', align_corners=True)
        elif dec.size(2) != en.size(2):
            raise ValueError('decoder input size {} does not match encoder size {}'.format(
                dec.size(2), en.size(2)))

        en = F.relu(self.bn_en(en))
        fmap = F.relu(self.conv1(torch.cat((en, dec), dim=1)))  # F in the paper

        if self.mode == 'C':
            dec_out = self.conv2(fmap)
            return dec_out, torch.sigmoid(dec_out), None

        fmap_att, attention = self.picanet(fmap)  # F_att in the paper
        x = torch.cat((fmap, fmap_att), 1)
        dec_out = F.relu(self.bn_feature(self.conv2(x)))
        _y = torch.sigmoid(self.conv3(dec_out))
        return dec_out, _y, attention


class PicanetG(nn.Module):
    """Global PiCANet: every pixel attends over a 10x10 grid with dilation 3.

    The dilated grid spans 3 * 9 + 1 = 28 pixels and is unfolded without
    padding, so the reshape in ``forward`` requires a 28x28 input.
    """

    def __init__(self, size, in_channel):
        super(PicanetG, self).__init__()
        self.renet = Renet(size, in_channel, 100)
        self.in_channel = in_channel

    def forward(self, *input):
        """Return ``(attended features, attention maps)`` for ``input[0]``.

        Attention maps are indexed ``[batch, h, w]``; each is the pixel's 10x10
        weight grid, bilinearly upsampled to 224x224 for visualisation.
        """
        x = input[0]
        size = x.size()
        kernel = F.softmax(self.renet(x), 1)  # (B, 100, H, W)
        kernel = kernel.reshape(size[0], 100, -1)  # (B, 100, H*W)
        x = F.unfold(x, [10, 10], dilation=[3, 3])  # (B, C*100, 1)
        x = x.reshape(size[0], size[1], 10 * 10)
        x = torch.matmul(x, kernel)  # (B, C, H*W)
        x = x.reshape(size[0], size[1], size[2], size[3])

        # Put the pixel axis first before splitting the 100 weights into a
        # 10x10 grid; reshaping (B, 100, H*W) directly would mix pixels.
        attention = kernel.detach().transpose(1, 2)  # (B, H*W, 100)
        attention = attention.reshape(size[0], -1, 10, 10)
        attention = F.interpolate(attention, 224, mode='bilinear', align_corners=True)
        attention = attention.reshape(size[0], size[2], size[3], 224, 224)
        return x, attention


class PicanetL(nn.Module):
    """Local PiCANet: every pixel attends over its 7x7 neighbourhood (dilation 2)."""

    def __init__(self, in_channel):
        super(PicanetL, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, 128, kernel_size=7, dilation=2, padding=6)
        self.conv2 = nn.Conv2d(128, 49, kernel_size=1)

    def forward(self, *input):
        """Return ``(attended features, attention maps)`` for ``input[0]``.

        Output pixel ``(h, w)`` is the softmax-weighted sum of the input at
        ``(h + 2i - 6, w + 2j - 6)`` for ``i, j`` in ``0..6`` (zero padded).
        Attention maps are indexed ``[batch, h, w]``; each is the pixel's 7x7
        weight grid, bilinearly upsampled for visualisation.
        """
        x = input[0]
        size = x.size()
        hw = size[2] * size[3]
        kernel = F.softmax(self.conv2(self.conv1(x)), 1)  # (B, 49, H, W)

        # F.unfold returns (B, C * 49, H*W) with the 49 neighbour offsets
        # inside each channel, so split it as (C, 49, H*W) and match the
        # kernel as (1, 49, H*W); the weighted sum runs over the 49 axis.
        kernel_flat = kernel.reshape(size[0], 1, 7 * 7, hw)
        x = F.unfold(x, kernel_size=[7, 7], dilation=[2, 2], padding=6)
        x = x.reshape(size[0], size[1], 7 * 7, hw)
        x = (x * kernel_flat).sum(dim=2)
        x = x.reshape(size[0], size[1], size[2], size[3])

        # (7 - 1) * 2 = 12 is the window's span in feature pixels, presumably
        # rescaled here to 224-pixel input coordinates.
        side = int(12 * 224 / size[2] + 1)
        attention = kernel.detach().reshape(size[0], 7 * 7, hw).transpose(1, 2)  # (B, H*W, 49)
        attention = attention.reshape(size[0], -1, 7, 7)
        attention = F.interpolate(attention, side, mode='bilinear', align_corners=True)
        attention = attention.reshape(size[0], size[2], size[3], side, side)
        return x, attention


class Renet(nn.Module):
    """ReNet: two bidirectional LSTM sweeps followed by a 1x1 conv.

    The first LSTM runs along the width of every row, the second along the
    height of every column of the first sweep's output.
    """

    def __init__(self, size, in_channel, out_channel):
        super(Renet, self).__init__()
        self.size = size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.vertical = nn.LSTM(input_size=in_channel, hidden_size=256, batch_first=True,
                                bidirectional=True)  # each row
        self.horizontal = nn.LSTM(input_size=512, hidden_size=256, batch_first=True,
                                  bidirectional=True)  # each column
        self.conv = nn.Conv2d(512, out_channel, 1)

    def forward(self, *input):
        """Map ``(B, in_channel, H, W)`` to ``(B, out_channel, H, W)``."""
        x = input[0]
        x = torch.transpose(x, 1, 3)  # batch, width, height, in_channel
        # Iterate over the input's real extent rather than self.size so
        # non-square or differently sized inputs are handled in full.
        rows = [self.vertical(x[:, :, i, :])[0] for i in range(x.size(2))]
        x = torch.stack(rows, dim=2)  # batch, width, height, 512
        cols = [self.horizontal(x[:, i, :, :])[0] for i in range(x.size(1))]
        x = torch.stack(cols, dim=3)  # batch, height, 512, width
        x = torch.transpose(x, 1, 2)  # batch, 512, height, width
        return self.conv(x)


if __name__ == '__main__':
    if torchvision is None:
        raise SystemExit('torchvision is required to run this demo')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vgg = torchvision.models.vgg16(pretrained=True)
    batch_size = 1
    noise = torch.randn((batch_size, 3, 224, 224), device=device)
    # BCE targets must lie in [0, 1].
    target = torch.rand((batch_size, 1, 224, 224), device=device)

    model = Unet(cfg).to(device)
    model.encoder.seq.load_state_dict(vgg.features.state_dict())
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    print('Time: {}'.format(time.perf_counter()))
    _, loss, _ = model(noise, target)
    loss.backward()

    # Optional timing loop; raise n_iters to run it.
    n_iters = 0
    for _ in range(n_iters):
        opt.zero_grad()
        time_spend = time.perf_counter()
        _, loss, _ = model(noise, target)
        print('Time_Spend: {}'.format(time.perf_counter() - time_spend))
        loss.backward()
        opt.step()
        print(float(loss))
        print('Time: {}'.format(time.perf_counter()))