import time

import cv2
import matplotlib.pyplot as plt  # visualization
import numpy as np
import torch
import torch.nn as nn
import torchvision.utils as vutils  # visualization


class Infer(object):
    """Multi-scale mask-proposal inference around a trunk / mask / score model.

    The wrapped ``model`` must expose ``trunk``, ``maskBranch`` and
    ``scoreBranch`` callables.  Typical use is ``forward(img)`` followed by
    ``getTopProps(thr, h, w)``.

    Attributes set by ``forward``:
        mask:  list with one NumPy array of sigmoid mask outputs per scale.
        score: list with one NumPy array of sigmoid score outputs per scale.

    Attribute set by ``getTopScores``:
        topScores: list of ``[score, scale_index, flat_index, row, col]``
            entries, best score first, at most ``nps`` long.
    """

    def __init__(self, model, nps=500, scales=(1.,),
                 meanstd={'mean': [0.485, 0.456, 0.406],
                          'std': [0.229, 0.224, 0.225]},
                 iSz=160, device='cpu', timer=True):
        """Store the model parts and pre-build the per-scale resizers.

        Args:
            model: object exposing ``trunk``, ``maskBranch``, ``scoreBranch``.
            nps: maximum number of proposals to keep.
            scales: scale factors of the image pyramid.
            meanstd: dict with 3-element ``'mean'`` and ``'std'`` lists used
                for channel normalisation (only read, never mutated).
            iSz: side length the masks are resized to; half of it is used as
                the border padding added around every pyramid level.
            device: device the normalisation constants are placed on.
            timer: when true, ``getTopProps`` prints the timing table.
        """
        self.trunk = model.trunk
        self.mHead = model.maskBranch
        self.sHead = model.scoreBranch
        self.nps = nps
        # Broadcastable (1, 3, 1, 1) normalisation constants.
        self.mean = torch.as_tensor(
            meanstd['mean'], dtype=torch.float32).view(1, 3, 1, 1).to(device)
        self.std = torch.as_tensor(
            meanstd['std'], dtype=torch.float32).view(1, 3, 1, 1).to(device)
        self.iSz, self.bw = iSz, iSz // 2
        self.device = device
        # Slots: pyramid, pre-process, trunk, mask head, score head, post-proc.
        self.timer = np.zeros(6)
        self.display_time = timer
        self.scales = scales
        self.pyramid = [nn.UpsamplingBilinear2d(scale_factor=s).to(device)
                        for s in self.scales]

    def forward(self, img):
        """Run the model on every pyramid level of ``img``.

        Results are stored in ``self.mask`` and ``self.score`` (one NumPy
        array per scale, after a sigmoid); nothing is returned.

        Args:
            img: image batch tensor; assumed to be (N, 3, H, W) on the same
                device as the normalisation constants -- TODO confirm with
                callers.
        """
        # The per-stage timers below are accumulated over the scales, so they
        # have to restart from zero on every call; otherwise the reported
        # timings keep growing from image to image.
        self.timer[:] = 0.0
        self.mask, self.score = [], []

        # The pad layer has no parameters, so it needs no device placement and
        # can be shared by all scales.
        pad = nn.ConstantPad2d(self.bw, 0.5)

        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            tic = time.time()
            imgPyramid = [resize(img) for resize in self.pyramid]
            self.timer[0] = time.time() - tic

            for inp in imgPyramid:
                tic = time.time()
                imgPad = pad(inp)
                # Debug: cv2.imshow('pad image', np.transpose(
                #     imgPad.squeeze().cpu().data.numpy(), axes=(1, 2, 0))[:, ::-1])
                # Debug: cv2.waitKey(0)
                imgPad = imgPad.sub_(self.mean).div_(self.std)
                self.timer[1] += time.time() - tic

                tic = time.time()
                outTrunk = self.trunk(imgPad)
                self.timer[2] += time.time() - tic

                tic = time.time()
                outMask = self.mHead(outTrunk)
                self.timer[3] += time.time() - tic

                tic = time.time()
                outScore = self.sHead(outTrunk)
                self.timer[4] += time.time() - tic

                # Debug: mask_show = vutils.make_grid(
                #     outMask.sigmoid().transpose(0, 1),
                #     nrow=outScore.shape[-1], pad_value=0)
                # Debug: mask_show_numpy = np.transpose(
                #     mask_show.cpu().data.numpy(), axes=(1, 2, 0))
                # Debug: plt.imshow(mask_show_numpy[:, :, 0], cmap='jet')
                # Debug: plt.show()
                self.mask.append(outMask.sigmoid().detach().cpu().numpy())
                self.score.append(outScore.sigmoid().detach().cpu().numpy())

    def getTopScores(self):
        """Collect the ``nps`` best-scoring positions over all scales.

        Stores the result in ``self.topScores`` as a list of
        ``[score, scale_index, flat_index, flat_index // w, flat_index % w]``
        sorted by descending score, where ``w`` is the last dimension of the
        scale's score array.  Ties keep scale order, then flat-index order.
        """
        nps = self.nps
        candidates = []
        for scale_idx, sc in enumerate(self.score):
            w = sc.shape[-1]
            flat = sc.flatten()
            # Only the best ``nps`` entries of a scale can reach the global
            # top ``nps``; the stable sort keeps ties in index order.
            order = np.argsort(-flat, kind='stable')[:nps]
            for sid in order.tolist():
                candidates.append(
                    [flat[sid], scale_idx, sid, sid // w, sid % w])
        # list.sort is stable, also with reverse=True.
        candidates.sort(key=lambda entry: entry[0], reverse=True)
        self.topScores = candidates[:nps]

    def crop_hwc(self, image, bbox, out_sz, padding=0):
        """Affine-crop ``bbox`` out of ``image`` and resize it to ``out_sz``.

        Args:
            image: array accepted by ``cv2.warpAffine``.
            bbox: ``(x, y, width, height)`` of the source region.
            out_sz: ``(width, height)`` of the output.
            padding: constant border value for pixels outside ``image``.

        Returns:
            The warped crop as returned by ``cv2.warpAffine``.
        """
        a = (out_sz[0] - 1) / bbox[2]
        b = (out_sz[1] - 1) / bbox[3]
        c = -a * bbox[0]
        d = -b * bbox[1]
        # np.float was removed from NumPy; float64 is what it aliased.
        mapping = np.array([[a, 0, c],
                            [0, b, d]], dtype=np.float64)
        crop = cv2.warpAffine(image, mapping, (out_sz[0], out_sz[1]),
                              borderMode=cv2.BORDER_CONSTANT,
                              borderValue=padding)
        return crop

    def getTopMasks(self, thr, h, w):
        """Build binary full-image masks for the entries of ``self.topScores``.

        Args:
            thr: threshold applied to the (sigmoid) mask values.
            h: output mask height.
            w: output mask width.

        Returns:
            ``(topMasks, topScores)``: a ``(h, w, nps)`` uint8 array and a
            length-``nps`` score array.  When fewer than ``nps`` candidates
            exist, the remaining slots stay zero.
        """
        masks, ts, nps = self.mask, self.topScores, self.nps
        # thr = np.log(thr / (1 - thr))
        topMasks = np.zeros((h, w, nps), dtype=np.uint8)
        topScores = np.zeros(nps)
        # Slicing guards against having fewer candidates than ``nps``.
        for i, (score, scale, sid, x, y) in enumerate(ts[:nps]):
            s = self.scales[scale]
            mask = cv2.resize(masks[scale][0, sid], (self.iSz, self.iSz))
            # Debug: cv2.imshow('mask', mask)
            # Debug: cv2.waitKey(0)
            # NOTE(review): 16 is presumably the trunk's output stride in
            # pixels -- confirm against the model definition.
            region = (self.bw - 16 * y, self.bw - 16 * x, w * s, h * s)
            imgMask = self.crop_hwc(mask, region, (w, h))
            topMasks[:, :, i] = imgMask > thr
            topScores[i] = score
        return topMasks, topScores

    def getTopProps(self, thr, h, w):
        """Rank the scores, build the top masks and optionally print timings.

        Must be called after ``forward``.  Arguments and return value are
        those of ``getTopMasks``.
        """
        tic = time.time()
        self.getTopScores()
        topMasks, topScores = self.getTopMasks(thr, h, w)
        self.timer[5] = time.time() - tic
        if self.display_time:
            self.printTiming()
        return topMasks, topScores

    def printTiming(self):
        """Print the per-stage timings and their total in milliseconds."""
        t = self.timer
        stages = ('pyramid', 'pre-process', 'trunk', 'mask branch',
                  'score branch', 'post processing')
        for label, seconds in zip(stages, t):
            print('| time {}: {:.1f} ms'.format(label, seconds * 1000))
        print('| time total: {:.1f} ms'.format(sum(t) * 1000))