import torch
import torch.nn as nn
import numpy as np
import scipy.misc
import cv2

from .misc import *

# NOTE(review): `plt` (used by the display helpers below) is not imported in
# this module; it presumably arrives via `from .misc import *` -- confirm.


def im_to_numpy(img):
    """Convert a C x H x W image to an H x W x C numpy array via ``to_numpy``."""
    img = to_numpy(img)
    img = np.transpose(img, (1, 2, 0))  # H*W*C
    return img


def im_to_torch(img):
    """Convert an H x W x C array to a C x H x W float tensor.

    Values are divided by 255 only when the maximum exceeds 1, so input that
    is already in [0, 1] passes through unscaled.
    """
    img = np.transpose(img, (2, 0, 1))  # C*H*W
    img = to_torch(img).float()
    if img.max() > 1:
        img /= 255
    return img


def _imresize(arr, size):
    """Resize ``arr`` to ``size`` = (rows, cols) and return a uint8 array.

    Stand-in for ``scipy.misc.imresize`` (removed in SciPy 1.3), which the
    callers in this module were written against. It reproduces the part of
    that function's default behaviour they rely on: non-uint8 input is first
    byte-scaled by mapping its overall min..max linearly onto 0..255 (a
    constant array maps to 0). Interpolation uses OpenCV -- INTER_AREA when
    shrinking, INTER_LINEAR otherwise -- so pixel values can differ slightly
    from the old PIL bilinear resize.
    """
    arr = np.asarray(arr)
    if arr.dtype != np.uint8:
        data = arr.astype(np.float64)
        lo = float(data.min())
        span = float(data.max()) - lo
        if span == 0:
            span = 1.0
        data = ((data - lo) * (255.0 / span)).clip(0, 255) + 0.5
        arr = data.astype(np.uint8)
    rows, cols = int(size[0]), int(size[1])
    if rows < arr.shape[0] or cols < arr.shape[1]:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_LINEAR
    # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
    return cv2.resize(np.ascontiguousarray(arr), (cols, rows), interpolation=interp)


def load_image(img_path):
    """Load an image file as an RGB C x H x W float tensor (see ``im_to_torch``)."""
    # scipy.misc.imread(mode='RGB') no longer exists in current SciPy. cv2
    # (already imported) decodes to 8-bit BGR, so convert to RGB to keep the
    # channel order callers expect.
    bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError('Could not read image: %s' % img_path)
    return im_to_torch(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))


def resize(img, owidth, oheight):
    """Resize a C x H x W image to ``oheight`` x ``owidth``.

    The result goes through ``im_to_torch``. As with the former
    ``scipy.misc.imresize`` path, non-uint8 input is min-max stretched to
    0..255 before resizing, so a non-constant image comes back in [0, 1].
    """
    img = im_to_numpy(img)
    img = _imresize(img, (oheight, owidth))
    return im_to_torch(img)


def generate_heatmap(heatmap, pt, sigma):
    """Render a blurred peak at ``pt`` = (x, y), rescaled so its maximum is 255.

    Args:
        heatmap: array indexed as ``heatmap[y][x]`` (presumably a 2-D float
            array of zeros -- confirm with callers). The pixel at ``pt`` is
            set to 1 *in place* before blurring.
        pt: (x, y) coordinates; each is truncated with ``int``.
        sigma: passed to ``cv2.GaussianBlur`` as its kernel size ``ksize``
            (e.g. ``(7, 7)``); with sigmaX=0 OpenCV derives the Gaussian
            sigma from that kernel size.

    Returns:
        The blurred array divided by ``max / 255`` (peak value 255).
    """
    heatmap[int(pt[1])][int(pt[0])] = 1
    heatmap = cv2.GaussianBlur(heatmap, sigma, 0)
    am = np.amax(heatmap)
    heatmap /= am / 255
    return heatmap


# =============================================================================
# Helpful display functions
# =============================================================================

def gauss(x, a, b, c, d=0):
    """Evaluate a Gaussian bump with height ``a``, centre ``b``, width ``c``, offset ``d``."""
    return a * np.exp(-(x - b)**2 / (2 * c**2)) + d


def color_heatmap(x):
    """Map a 2-D heatmap to an H x W x 3 uint8 false-colour image.

    Each channel is a sum of Gaussian bumps of the heatmap value; the curves
    are tuned for values presumably in [0, 1].
    """
    x = to_numpy(x)
    color = np.zeros((x.shape[0], x.shape[1], 3))
    color[:, :, 0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
    color[:, :, 1] = gauss(x, 1, .5, .3)
    color[:, :, 2] = gauss(x, 1, .2, .3)
    color[color > 1] = 1
    color = (color * 255).astype(np.uint8)
    return color


def imshow(img):
    """Draw a C x H x W image with values presumably in [0, 1] via matplotlib."""
    npimg = im_to_numpy(img * 255).astype(np.uint8)
    plt.imshow(npimg)
    plt.axis('off')


def show_joints(img, pts):
    """Draw ``img`` and mark every joint whose ``pts[i, 2]`` is positive.

    ``pts`` is indexed as (joint, [x, y, visibility]).
    """
    imshow(img)
    for i in range(pts.size(0)):
        if pts[i, 2] > 0:
            plt.plot(pts[i, 0], pts[i, 1], 'yo')
    plt.axis('off')


def show_sample(inputs, target):
    """Show each sample's input followed by one overlay per joint heatmap.

    Args:
        inputs: batch of images; ``inputs[n]`` is C x H x W.
        target: heatmaps of shape (N, joints, height, width). Inputs are
            resized to (height, width) before blending.
    """
    num_sample = inputs.size(0)
    num_joints = target.size(1)
    height = target.size(2)
    width = target.size(3)

    for n in range(num_sample):
        inp = resize(inputs[n], width, height)
        out = inp
        for p in range(num_joints):
            # color_heatmap returns H x W x 3 uint8; convert it to the same
            # C x H x W float [0, 1] layout as `inp` before blending, otherwise
            # the shapes do not broadcast and the value ranges disagree.
            color = color_heatmap(target[n, p, :, :])
            hm = torch.from_numpy(color).permute(2, 0, 1).float() / 255
            tgt = inp * 0.5 + hm * 0.5
            out = torch.cat((out, tgt), 2)

        imshow(out)
        plt.show()


def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):
    """Compose the input image and a grid of heatmap overlays into one uint8 image.

    Args:
        inp: C x H x W image with values presumably in [0, 1] (scaled by 255
            here); only the first 3 channels are used.
        out: per-part heatmaps, ``out[part]`` being a 2-D map.
        num_rows: rows in the overlay grid; each cell is ``H // num_rows``
            pixels square.
        parts_to_show: indices into ``out`` to display; defaults to all.

    Returns:
        uint8 array of shape (H, W + cell * num_cols, 3): the input on the
        left, then the overlays filled row by row.
    """
    # Detach so tensors that require grad can be converted to numpy.
    if torch.is_tensor(inp):
        inp = inp.detach()
    if torch.is_tensor(out):
        out = out.detach()
    inp = to_numpy(inp * 255)
    out = to_numpy(out)

    img = np.transpose(inp[:3], (1, 2, 0)).astype(np.float64)  # H x W x 3

    if parts_to_show is None:
        parts_to_show = np.arange(out.shape[0])

    # Generate a single image to display input/output pair
    num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows))
    size = img.shape[0] // num_rows
    # The input occupies the left img.shape[1] columns and the heatmap cells
    # follow it. For square inputs whose height divides by num_rows this is
    # the same as the former fixed `size * num_rows` offset.
    left = img.shape[1]
    full_img = np.zeros((img.shape[0], left + size * num_cols, 3), np.uint8)
    full_img[:img.shape[0], :img.shape[1]] = img

    inp_small = _imresize(img, [size, size])

    # Set up heatmap display for each part
    for i, part in enumerate(parts_to_show):
        # _imresize min-max stretches each map to 0..255, so this is [0, 1].
        out_resized = _imresize(out[part], [size, size]).astype(float) / 255
        out_img = inp_small * .3
        out_img += color_heatmap(out_resized) * .7

        col_offset = left + (i % num_cols) * size
        row_offset = (i // num_cols) * size
        full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img

    return full_img


def batch_with_heatmap(inputs, outputs, mean=None, num_rows=2, parts_to_show=None):
    """Stack ``sample_with_heatmap`` images for up to the first 4 samples.

    Args:
        inputs: N x C x H x W image batch; ``mean`` is added to each sample
            (presumably undoing mean subtraction) and the result is clamped to
            [0, 1].
        outputs: per-sample heatmaps, ``outputs[n]``.
        mean: per-channel values (tensor or sequence); defaults to
            ``torch.Tensor([0.5, 0.5, 0.5])``.
        num_rows, parts_to_show: forwarded to ``sample_with_heatmap``.

    Returns:
        The per-sample images concatenated along axis 0.
    """
    if mean is None:
        # Built per call rather than as a module-level default argument.
        mean = torch.Tensor([0.5, 0.5, 0.5])
    # Match the inputs' device/dtype (so CUDA batches work) and broadcast the
    # per-channel values over H x W.
    mean = torch.as_tensor(mean, dtype=inputs.dtype, device=inputs.device).view(-1, 1, 1)

    batch_img = []
    for n in range(min(inputs.size(0), 4)):
        inp = inputs[n] + mean
        batch_img.append(
            sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show)
        )
    return np.concatenate(batch_img)