""" Helpers for cropping image depending on resolution. TODO: replace recursive code with something more adaptive, right now we only do - no crops / 2x2 / 4x4 / 16x16 / etc. and the stitching code is complicated, but would be nice to have e.g. 3x3. """ import math import os import itertools import torch from blueprints.multiscale_blueprint import MultiscaleLoss import functools import operator def prod(it): return functools.reduce(operator.mul, it, 1) # Images with H * W > prod(_NEEDS_CROP_DIM) will be split into crops # We set this empirically such that crops fit into our TITAN X (Pascal) with 12GB VRAM. # You can set this from the console using AC_NEEDS_CROP_DIM, e.g., # # AC_NEEDS_CROP_DIM=2000,2000 python test.py ... # # But expect OOM errors for big values. _NEEDS_CROP_DIM_DEFAULT = '2000,1500' _NEEDS_CROP_DIM = os.environ.get('AC_NEEDS_CROP_DIM', _NEEDS_CROP_DIM_DEFAULT) if _NEEDS_CROP_DIM != _NEEDS_CROP_DIM_DEFAULT: print('*** AC_NEEDS_CROP_DIM =', _NEEDS_CROP_DIM) _NEEDS_CROP_DIM = prod(map(int, _NEEDS_CROP_DIM.split(','))) print('*** AC_NEEDS_CROP_DIM =', _NEEDS_CROP_DIM) def _assert_valid_image(i): if len(i.shape) != 4 or i.shape[1] != 3: raise ValueError(f'Expected BCHW image, got {i.shape}') def needs_crop(img, needs_crop_dim=_NEEDS_CROP_DIM): _assert_valid_image(img) H, W = img.shape[-2:] return H * W > needs_crop_dim def _crop16(im): for im_cropped in _crop4(im): yield from _crop4(im_cropped) def iter_crops(img, needs_crop_dim=_NEEDS_CROP_DIM): _assert_valid_image(img) if not needs_crop(img, needs_crop_dim): yield img return for img_crop in _crop4(img): yield from iter_crops(img_crop, needs_crop_dim) def _crop4(img): _assert_valid_image(img) H, W = img.shape[-2:] imgs = [img[..., :H//2, :W//2], # Top left img[..., :H//2, W//2:], # Top right img[..., H//2:, :W//2], # Bottom left img[..., H//2:, W//2:]] # Bottom right # Validate that we got all pixels assert sum(prod(img.shape[-2:]) for img in imgs) == \ prod(img.shape[-2:]) return imgs def _get_crop_idx_mapping(side): """Helper method to get the order of crops. :param side: how many crops live on each side. Example. Say you have an image that gets devided into 16 crops, i.e., the image gets cut into 16 parts: [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]], However, due to our recursive cropping code, this results in crops that are ordered like this: index of crop: 0 1 2 3 4 ... corresponds to part in image: 0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15 This method returns the inverse, going from the index of the crop back to the index in the image. """ a = torch.arange(side * side).reshape(1, 1, side, side) a = torch.cat((a, a, a), dim=1) # Create mapping # Index of crop in original image -> index of crop in the order it was extracted, # E.g. 2 -> 4 means it's the 2nd crop, but in the image, it's at position 4 (see above). crops = {i: crop[0, 0, ...].flatten().item() for i, crop in enumerate(iter_crops(a, 1))} return crops def stitch(parts): side = int(math.sqrt(len(parts))) if side * side != len(parts): raise ValueError(f'Invalid number of parts {len(parts)}') rows = [] # Sort by original position in image crops_idx_mapping = _get_crop_idx_mapping(side) parts_sorted = ( part for _, part in sorted( enumerate(parts), key=lambda ip: crops_idx_mapping[ip[0]])) parts_itr = iter(parts_sorted) # Turn into iterator so we can easily grab elements for _ in range(side): parts_row = itertools.islice(parts_itr, side) # Get `side` number of parts row = torch.cat(list(parts_row), dim=3) # cat on W dimension rows.append(row) assert next(parts_itr, None) is None, f'Iterator should be empty, got {len(rows)} rows' img = torch.cat(rows, dim=2) # cat on H dimension # Validate. B, C, H_part, W_part = parts[0].shape expected_shape = (B, C, H_part * side, W_part * side) assert img.shape == expected_shape, f'{img.shape} != {expected_shape}' return img class CropLossCombinator(object): """Used to combine the bpsp of different crops into one. Supports crops of varying dimensions.""" def __init__(self): self._num_bits_total = 0. self._num_subpixels_total = 0 def add(self, bpsp, num_subpixels_crop): bits = bpsp * num_subpixels_crop self._num_bits_total += bits self._num_subpixels_total += num_subpixels_crop def get_bpsp(self): assert self._num_subpixels_total > 0 return self._num_bits_total / self._num_subpixels_total def test_auto_crop(): import torch import pytorch_ext as pe for H, W, num_crops_expected in [(10000, 6000, 64), (4928, 3264, 16), (2048, 2048, 4), (1024, 1024, 1), ]: img = (torch.rand(1, 3, H, W) * 255).round().long() print(img.shape) if num_crops_expected > 1: assert needs_crop(img) crops = list(iter_crops(img, 2048 * 1024)) assert len(crops) == num_crops_expected pe.assert_equal(stitch(crops), img) else: pe.assert_equal(next(iter_crops(img, 2048 * 1024)), img)