import re
import cPickle as pickle
from collections import namedtuple

import numpy as np
import skimage.draw
import skimage.exposure
import skimage.filters
import skimage.transform
from scipy.fftpack import fftn, ifftn
from skimage.feature import peak_local_max, canny
from skimage.transform import hough_circle

from configuration import config
import utils


def read_labels(file_path):
    """Read the training CSV into {patient_id: [systole, diastole]}."""
    id2labels = {}
    with open(file_path) as train_csv:
        for line in train_csv.readlines()[1:]:  # skip the header line
            id, systole, diastole = line.replace('\n', '').split(',')
            id2labels[int(id)] = [float(systole), float(diastole)]
    return id2labels


def read_slice(path):
    return pickle.load(open(path, 'rb'))['data']


def read_fft_slice(path):
    d = pickle.load(open(path, 'rb'))['data']
    # keep the first temporal harmonic: moving structures (the heart) light up
    ff1 = fftn(d)
    fh = np.absolute(ifftn(ff1[1, :, :]))
    fh[fh < 0.1 * np.max(fh)] = 0.0
    d = 1. * fh / np.max(fh)
    d = np.expand_dims(d, axis=0)
    return d


def read_metadata(path):
    d = pickle.load(open(path, 'rb'))['metadata'][0]
    metadata = {k: d[k] for k in ['PixelSpacing', 'ImageOrientationPatient',
                                  'ImagePositionPatient', 'SliceLocation',
                                  'PatientSex', 'PatientAge', 'Rows', 'Columns']}
    metadata['PixelSpacing'] = np.float32(metadata['PixelSpacing'])
    metadata['ImageOrientationPatient'] = np.float32(metadata['ImageOrientationPatient'])
    metadata['SliceLocation'] = np.float32(metadata['SliceLocation'])
    metadata['ImagePositionPatient'] = np.float32(metadata['ImagePositionPatient'])
    metadata['PatientSex'] = 1 if metadata['PatientSex'] == 'F' else -1
    metadata['PatientAge'] = utils.get_patient_age(metadata['PatientAge'])
    metadata['Rows'] = int(metadata['Rows'])
    metadata['Columns'] = int(metadata['Columns'])
    return metadata


def sample_augmentation_parameters(transformation):
    # TODO: bad thing to mix fixed and random params!!!
    # a transformation dict holding only fixed keys implies no augmentation
    if set(transformation.keys()) == {'patch_size', 'mm_patch_size'} or \
            set(transformation.keys()) == {'patch_size', 'mm_patch_size', 'mask_roi'}:
        return None

    shift_x = config().rng.uniform(*transformation.get('translation_range_x', [0., 0.]))
    shift_y = config().rng.uniform(*transformation.get('translation_range_y', [0., 0.]))
    translation = (shift_x, shift_y)
    rotation = config().rng.uniform(*transformation.get('rotation_range', [0., 0.]))
    shear = config().rng.uniform(*transformation.get('shear_range', [0., 0.]))
    roi_scale = config().rng.uniform(*transformation.get('roi_scale_range', [1., 1.]))
    z = config().rng.uniform(*transformation.get('zoom_range', [1., 1.]))
    zoom = (z, z)

    if 'do_flip' in transformation:
        if type(transformation['do_flip']) == tuple:
            flip_x = config().rng.randint(2) > 0 if transformation['do_flip'][0] else False
            flip_y = config().rng.randint(2) > 0 if transformation['do_flip'][1] else False
        else:
            flip_x = config().rng.randint(2) > 0 if transformation['do_flip'] else False
            flip_y = False
    else:
        flip_x, flip_y = False, False

    sequence_shift = config().rng.randint(30) if transformation.get('sequence_shift', False) else 0

    return namedtuple('Params', ['translation', 'rotation', 'shear', 'zoom',
                                 'roi_scale', 'flip_x', 'flip_y', 'sequence_shift'])(
        translation, rotation, shear, zoom, roi_scale, flip_x, flip_y, sequence_shift)
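# A minimal usage sketch of the readers and the parameter sampler. The paths
# and the transformation dict below are hypothetical; they assume the pickled
# slice files and train.csv produced by the preprocessing step.
#
#   id2labels = read_labels('data/train.csv')  # {id: [systole, diastole]}
#   metadata = read_metadata('data/pkl_train/1/study/sax_10.pkl')
#   params = sample_augmentation_parameters({'patch_size': (64, 64),
#                                            'mm_patch_size': (128, 128),
#                                            'rotation_range': [-15., 15.]})
#   # params is a Params namedtuple here; it would be None if the dict held
#   # only the fixed keys patch_size / mm_patch_size (/ mask_roi).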
def transform_norm_rescale(data, metadata, transformation, roi=None, random_augmentation_params=None,
                           mm_center_location=(.5, .4), mm_patch_size=(128, 128), mask_roi=True):
    patch_size = transformation['patch_size']
    mm_patch_size = transformation.get('mm_patch_size', mm_patch_size)
    mask_roi = transformation.get('mask_roi', mask_roi)

    out_shape = (30,) + patch_size
    out_data = np.zeros(out_shape, dtype='float32')

    roi_center = roi['roi_center'] if roi else None
    roi_radii = roi['roi_radii'] if roi else None

    # correct orientation
    data, roi_center, roi_radii = correct_orientation(data, metadata, roi_center, roi_radii)

    # if random_augmentation_params is None, sample new params;
    # if the transformation implies no augmentation, it remains None
    if not random_augmentation_params:
        random_augmentation_params = sample_augmentation_parameters(transformation)

    # build scaling transformation
    pixel_spacing = metadata['PixelSpacing']
    assert pixel_spacing[0] == pixel_spacing[1]
    current_shape = data.shape[-2:]

    # scale ROI radii and find the ROI center in the normalized patch
    if roi_center:
        mm_center_location = tuple(int(r * ps) for r, ps in zip(roi_center, pixel_spacing))

    # scale the images such that they all have the same scale
    norm_rescaling = 1. / pixel_spacing[0]
    mm_shape = tuple(int(float(d) * ps) for d, ps in zip(current_shape, pixel_spacing))

    tform_normscale = build_rescale_transform(downscale_factor=norm_rescaling,
                                              image_shape=current_shape, target_shape=mm_shape)
    tform_shift_center, tform_shift_uncenter = build_shift_center_transform(
        image_shape=mm_shape, center_location=mm_center_location, patch_size=mm_patch_size)

    patch_scale = max(1. * mm_patch_size[0] / patch_size[0],
                      1. * mm_patch_size[1] / patch_size[1])
    tform_patch_scale = build_rescale_transform(patch_scale, mm_patch_size, target_shape=patch_size)

    total_tform = tform_patch_scale + tform_shift_uncenter + tform_shift_center + tform_normscale

    # build random augmentation
    if random_augmentation_params:
        augment_tform = build_augmentation_transform(rotation=random_augmentation_params.rotation,
                                                     shear=random_augmentation_params.shear,
                                                     translation=random_augmentation_params.translation,
                                                     flip_x=random_augmentation_params.flip_x,
                                                     flip_y=random_augmentation_params.flip_y,
                                                     zoom=random_augmentation_params.zoom)
        total_tform = tform_patch_scale + tform_shift_uncenter + augment_tform + tform_shift_center + tform_normscale

    # apply transformation per image
    for i in xrange(data.shape[0]):
        out_data[i] = fast_warp(data[i], total_tform, output_shape=patch_size)

    normalize_contrast_zmuv(out_data)

    # apply transformation to the ROI and mask the images
    if roi_center and roi_radii and mask_roi:
        roi_scale = random_augmentation_params.roi_scale if random_augmentation_params else 1  # augmentation
        roi_zoom = random_augmentation_params.zoom if random_augmentation_params else (1., 1.)
        rescaled_roi_radii = (roi_scale * roi_radii[0], roi_scale * roi_radii[1])
        out_roi_radii = (int(roi_zoom[0] * rescaled_roi_radii[0] * pixel_spacing[0] / patch_scale),
                         int(roi_zoom[1] * rescaled_roi_radii[1] * pixel_spacing[1] / patch_scale))
        roi_mask = make_circular_roi_mask(patch_size, (patch_size[0] / 2, patch_size[1] / 2), out_roi_radii)
        out_data *= roi_mask

    # if the sequence is shorter than 30 time steps, repeat the last image
    if data.shape[0] < out_shape[0]:
        for j in xrange(data.shape[0], out_shape[0]):
            out_data[j] = out_data[j - 1]

    # if longer than 30, drop the extra images
    if data.shape[0] > out_shape[0]:
        out_data = out_data[:30]

    # shift the sequence by a number of time steps
    if random_augmentation_params:
        out_data = np.roll(out_data, random_augmentation_params.sequence_shift, axis=0)

    if random_augmentation_params:
        targets_zoom_factor = random_augmentation_params.zoom[0] * random_augmentation_params.zoom[1]
    else:
        targets_zoom_factor = 1.

    return out_data, targets_zoom_factor
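# Sketch of a typical call (hypothetical shapes and ROI values): data is a
# (time, rows, cols) array for one slice, metadata comes from read_metadata,
# and roi from extract_roi below.
#
#   patch, zoom_factor = transform_norm_rescale(
#       data, metadata,
#       transformation={'patch_size': (64, 64), 'mm_patch_size': (128, 128)},
#       roi={'roi_center': (120, 130), 'roi_radii': (40, 45)})
#   # patch.shape == (30, 64, 64); zoom_factor corrects the volume targets
#   # for the sampled zoom augmentation.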
def transform_norm_rescale_after(data, metadata, transformation, roi=None, random_augmentation_params=None,
                                 mm_center_location=(.5, .4), mm_patch_size=(128, 128), mask_roi=True):
    """Same as transform_norm_rescale, except that contrast normalization
    is applied after ROI masking instead of before."""
    patch_size = transformation['patch_size']
    mm_patch_size = transformation.get('mm_patch_size', mm_patch_size)
    mask_roi = transformation.get('mask_roi', mask_roi)

    out_shape = (30,) + patch_size
    out_data = np.zeros(out_shape, dtype='float32')

    roi_center = roi['roi_center'] if roi else None
    roi_radii = roi['roi_radii'] if roi else None

    # correct orientation
    data, roi_center, roi_radii = correct_orientation(data, metadata, roi_center, roi_radii)

    # if random_augmentation_params is None, sample new params;
    # if the transformation implies no augmentation, it remains None
    if not random_augmentation_params:
        random_augmentation_params = sample_augmentation_parameters(transformation)

    # build scaling transformation
    pixel_spacing = metadata['PixelSpacing']
    assert pixel_spacing[0] == pixel_spacing[1]
    current_shape = data.shape[-2:]

    # scale ROI radii and find the ROI center in the normalized patch
    if roi_center:
        mm_center_location = tuple(int(r * ps) for r, ps in zip(roi_center, pixel_spacing))

    # scale the images such that they all have the same scale
    norm_rescaling = 1. / pixel_spacing[0]
    mm_shape = tuple(int(float(d) * ps) for d, ps in zip(current_shape, pixel_spacing))

    tform_normscale = build_rescale_transform(downscale_factor=norm_rescaling,
                                              image_shape=current_shape, target_shape=mm_shape)
    tform_shift_center, tform_shift_uncenter = build_shift_center_transform(
        image_shape=mm_shape, center_location=mm_center_location, patch_size=mm_patch_size)

    patch_scale = max(1. * mm_patch_size[0] / patch_size[0],
                      1. * mm_patch_size[1] / patch_size[1])
    tform_patch_scale = build_rescale_transform(patch_scale, mm_patch_size, target_shape=patch_size)

    total_tform = tform_patch_scale + tform_shift_uncenter + tform_shift_center + tform_normscale

    # build random augmentation
    if random_augmentation_params:
        augment_tform = build_augmentation_transform(rotation=random_augmentation_params.rotation,
                                                     shear=random_augmentation_params.shear,
                                                     translation=random_augmentation_params.translation,
                                                     flip_x=random_augmentation_params.flip_x,
                                                     flip_y=random_augmentation_params.flip_y,
                                                     zoom=random_augmentation_params.zoom)
        total_tform = tform_patch_scale + tform_shift_uncenter + augment_tform + tform_shift_center + tform_normscale

    # apply transformation per image
    for i in xrange(data.shape[0]):
        out_data[i] = fast_warp(data[i], total_tform, output_shape=patch_size)

    # apply transformation to the ROI and mask the images
    if roi_center and roi_radii and mask_roi:
        roi_scale = random_augmentation_params.roi_scale if random_augmentation_params else 1  # augmentation
        roi_zoom = random_augmentation_params.zoom if random_augmentation_params else (1., 1.)
        rescaled_roi_radii = (roi_scale * roi_radii[0], roi_scale * roi_radii[1])
        out_roi_radii = (int(roi_zoom[0] * rescaled_roi_radii[0] * pixel_spacing[0] / patch_scale),
                         int(roi_zoom[1] * rescaled_roi_radii[1] * pixel_spacing[1] / patch_scale))
        roi_mask = make_circular_roi_mask(patch_size, (patch_size[0] / 2, patch_size[1] / 2), out_roi_radii)
        out_data *= roi_mask

    # normalize after masking
    normalize_contrast_zmuv(out_data)

    # if the sequence is shorter than 30 time steps, repeat the last image
    if data.shape[0] < out_shape[0]:
        for j in xrange(data.shape[0], out_shape[0]):
            out_data[j] = out_data[j - 1]

    # if longer than 30, drop the extra images
    if data.shape[0] > out_shape[0]:
        out_data = out_data[:30]

    # shift the sequence by a number of time steps
    if random_augmentation_params:
        out_data = np.roll(out_data, random_augmentation_params.sequence_shift, axis=0)

    if random_augmentation_params:
        targets_zoom_factor = random_augmentation_params.zoom[0] * random_augmentation_params.zoom[1]
    else:
        targets_zoom_factor = 1.

    return out_data, targets_zoom_factor


def make_roi_mask(img_shape, roi_center, roi_radii):
    """Makes a 2D rectangular ROI mask for one slice.
    :param img_shape: shape of the mask
    :param roi_center: ROI center in (i, j) coordinates
    :param roi_radii: ROI radii in (i, j) coordinates
    :return: binary mask
    """
    mask = np.zeros(img_shape)
    mask[max(0, roi_center[0] - roi_radii[0]):min(roi_center[0] + roi_radii[0], img_shape[0]),
         max(0, roi_center[1] - roi_radii[1]):min(roi_center[1] + roi_radii[1], img_shape[1])] = 1
    return mask


def make_circular_roi_mask(img_shape, roi_center, roi_radii):
    mask = np.zeros(img_shape)
    rr, cc = skimage.draw.ellipse(roi_center[0], roi_center[1], roi_radii[0], roi_radii[1], img_shape)
    mask[rr, cc] = 1.
    return mask


tform_identity = skimage.transform.AffineTransform()


def fast_warp(img, tf, output_shape, mode='constant', order=1):
    """This wrapper function is faster than skimage.transform.warp."""
    m = tf.params  # this attribute was tf._matrix in older scikit-image versions
    return skimage.transform._warps_cy._warp_fast(img, m, output_shape=output_shape, mode=mode, order=order)


def build_centering_transform(image_shape, target_shape=(50, 50)):
    rows, cols = image_shape
    trows, tcols = target_shape
    shift_x = (cols - tcols) / 2.0
    shift_y = (rows - trows) / 2.0
    return skimage.transform.SimilarityTransform(translation=(shift_x, shift_y))


def build_rescale_transform(downscale_factor, image_shape, target_shape):
    """
    Estimating the correct rescaling transform is slow, so just use the
    downscale_factor to define a transform directly. This probably isn't
    100% correct, but it shouldn't matter much in practice.
    """
    rows, cols = image_shape
    trows, tcols = target_shape
    tform_ds = skimage.transform.AffineTransform(scale=(downscale_factor, downscale_factor))

    # centering
    shift_x = cols / (2.0 * downscale_factor) - tcols / 2.0
    shift_y = rows / (2.0 * downscale_factor) - trows / 2.0
    tform_shift_ds = skimage.transform.SimilarityTransform(translation=(shift_x, shift_y))
    return tform_shift_ds + tform_ds


def build_center_uncenter_transforms(image_shape):
    """
    These are used to ensure that zooming and rotation happens around the center of the image.
    Use these transforms to center and uncenter the image around such a transform.
    """
    # need to swap rows and cols here apparently! confusing!
    center_shift = np.array([image_shape[1], image_shape[0]]) / 2.0 - 0.5
    tform_uncenter = skimage.transform.SimilarityTransform(translation=-center_shift)
    tform_center = skimage.transform.SimilarityTransform(translation=center_shift)
    return tform_center, tform_uncenter
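# Illustrative composition of the transform helpers (shapes are arbitrary).
# skimage applies the leftmost transform first under +, and fast_warp treats
# the composed matrix as a mapping from output to input coordinates, which is
# why the pipelines above read patch-to-image from left to right.
#
#   tform_center, tform_uncenter = build_center_uncenter_transforms((256, 256))
#   tform_zoom = build_augmentation_transform(zoom=(1.2, 1.2))
#   warped = fast_warp(img, tform_uncenter + tform_zoom + tform_center,
#                      output_shape=(256, 256))  # zoom about the image center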
def build_augmentation_transform(rotation=0, shear=0, translation=(0, 0), flip_x=False, flip_y=False, zoom=(1.0, 1.0)):
    if flip_x:
        shear += 180  # shear by 180 degrees is equivalent to a flip along the X-axis
    if flip_y:
        shear += 180
        rotation += 180

    tform_augment = skimage.transform.AffineTransform(scale=(1. / zoom[0], 1. / zoom[1]),
                                                      rotation=np.deg2rad(rotation),
                                                      shear=np.deg2rad(shear),
                                                      translation=translation)
    return tform_augment


def correct_orientation(data, metadata, roi_center, roi_radii):
    # transpose the in-plane axes (and swap the ROI accordingly) when the
    # DICOM row direction is better aligned with the patient's y-axis
    F = metadata["ImageOrientationPatient"].reshape((2, 3))
    f_1 = F[1, :]
    f_2 = F[0, :]
    y_e = np.array([0, 1, 0])
    if abs(np.dot(y_e, f_1)) >= abs(np.dot(y_e, f_2)):
        out_data = np.transpose(data, (0, 2, 1))
        out_roi_center = (roi_center[1], roi_center[0]) if roi_center else None
        out_roi_radii = (roi_radii[1], roi_radii[0]) if roi_radii else None
    else:
        out_data = data
        out_roi_center = roi_center
        out_roi_radii = roi_radii

    return out_data, out_roi_center, out_roi_radii


def build_shift_center_transform(image_shape, center_location, patch_size):
    """Shifts the center of the image to a given location.
    This function tries to include as much as possible of the image in the
    patch centered around the new center. If the patch around the ideal
    center location doesn't fit within the image, we shift the center so
    that it does.
    params in (i, j) coordinates !!!
    """
    if center_location[0] < 1. and center_location[1] < 1.:
        center_absolute_location = [center_location[0] * image_shape[0],
                                    center_location[1] * image_shape[1]]
    else:
        center_absolute_location = [center_location[0], center_location[1]]

    # Check for overlap at the edges
    center_absolute_location[0] = max(center_absolute_location[0], patch_size[0] / 2.0)
    center_absolute_location[1] = max(center_absolute_location[1], patch_size[1] / 2.0)
    center_absolute_location[0] = min(center_absolute_location[0], image_shape[0] - patch_size[0] / 2.0)
    center_absolute_location[1] = min(center_absolute_location[1], image_shape[1] - patch_size[1] / 2.0)

    # Check for overlap at both edges
    if patch_size[0] > image_shape[0]:
        center_absolute_location[0] = image_shape[0] / 2.0
    if patch_size[1] > image_shape[1]:
        center_absolute_location[1] = image_shape[1] / 2.0

    # Build transform
    new_center = np.array(center_absolute_location)
    translation_center = new_center - 0.5
    translation_uncenter = -np.array((patch_size[0] / 2.0, patch_size[1] / 2.0)) - 0.5
    return (
        skimage.transform.SimilarityTransform(translation=translation_center[::-1]),
        skimage.transform.SimilarityTransform(translation=translation_uncenter[::-1]))


def normalize_contrast_zmuv(data, z=2):
    # zero-mean unit-variance normalization, squashed into [0, 1] in place
    mean = np.mean(data)
    std = np.std(data)
    for i in xrange(len(data)):
        img = data[i]
        img = (img - mean) / (2 * std * z) + 0.5
        data[i] = np.clip(img, 0.0, 1.0)
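# Sketch of the centering convention (illustrative numbers, (i, j) order):
# fractional center_location values are interpreted relative to image_shape.
#
#   shift_center, shift_uncenter = build_shift_center_transform(
#       image_shape=(200, 180), center_location=(.5, .4), patch_size=(128, 128))
#   # the patch center lands at row 100 (= .5 * 200) and column 72 (= .4 * 180),
#   # clipped so the 128x128 patch stays inside the image where possible.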
def slice_location_finder(slicepath2metadata):
    """
    :param slicepath2metadata: dict with arbitrary keys, and metadata values
    :return: dict mapping each slice path to its position along the stack axis
    """
    slicepath2midpix = {}
    slicepath2position = {}

    for sp, metadata in slicepath2metadata.iteritems():
        if 'sax' in sp:
            image_orientation = metadata["ImageOrientationPatient"]
            image_position = metadata["ImagePositionPatient"]
            pixel_spacing = metadata["PixelSpacing"]
            rows = metadata['Rows']
            columns = metadata['Columns']

            # calculate the value of the middle pixel
            # reversed order, as per http://nipy.org/nibabel/dicom/dicom_orientation.html
            F = np.array(image_orientation).reshape((2, 3))
            i, j = columns / 2.0, rows / 2.0
            im_pos = np.array([[i * pixel_spacing[0], j * pixel_spacing[1]]], dtype='float32')
            pos = np.array(image_position).reshape((1, 3))
            position = np.dot(im_pos, F) + pos
            slicepath2midpix[sp] = position[0, :]
        if '2ch' in sp:
            slicepath2position[sp] = -10
        if '4ch' in sp:
            slicepath2position[sp] = -50

    if len(slicepath2midpix) <= 1:
        for sp, midpix in slicepath2midpix.iteritems():
            slicepath2position[sp] = 0.
    else:
        # find the keys of the 2 points furthest away from each other
        max_dist = -1.0
        max_dist_keys = []
        for sp1, midpix1 in slicepath2midpix.iteritems():
            for sp2, midpix2 in slicepath2midpix.iteritems():
                if sp1 == sp2:
                    continue
                distance = np.sqrt(np.sum((midpix1 - midpix2) ** 2))
                if distance > max_dist:
                    max_dist_keys = [sp1, sp2]
                    max_dist = distance

        # project the others onto the line between these 2 points;
        # sort the keys, so the order stays more or less as it was
        max_dist_keys.sort(key=lambda x: int(re.search(r'/sax_(\d+)\.pkl$', x).group(1)))
        p_ref1 = slicepath2midpix[max_dist_keys[0]]
        p_ref2 = slicepath2midpix[max_dist_keys[1]]
        v1 = p_ref2 - p_ref1
        v1 /= np.linalg.norm(v1)
        for sp, midpix in slicepath2midpix.iteritems():
            v2 = midpix - p_ref1
            slicepath2position[sp] = np.inner(v1, v2)

    return slicepath2position
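# Example of the slice-ordering output (hypothetical paths; metadata dicts as
# returned by read_metadata). The sax keys must match the /sax_<n>.pkl pattern
# used in the sort above.
#
#   positions = slice_location_finder({'pkl_train/1/study/sax_5.pkl': meta5,
#                                      'pkl_train/1/study/sax_10.pkl': meta10})
#   # each sax path maps to its signed projection (in mm) onto the axis through
#   # the two most distant slice midpoints; 2ch/4ch slices get fixed sentinels.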
def extract_roi(data, pixel_spacing, minradius_mm=25, maxradius_mm=45, kernel_width=5, center_margin=8,
                num_peaks=10, num_circles=20, radstep=2):
    """
    Returns the center and radii of the ROI region in (i, j) format.
    """
    # radii of the smallest and largest circles in mm, estimated from the
    # train set; convert to pixel counts
    minradius = int(minradius_mm / pixel_spacing)
    maxradius = int(maxradius_mm / pixel_spacing)

    ximagesize = data[0]['data'].shape[1]
    yimagesize = data[0]['data'].shape[2]

    xsurface = np.tile(range(ximagesize), (yimagesize, 1)).T
    ysurface = np.tile(range(yimagesize), (ximagesize, 1))
    lsurface = np.zeros((ximagesize, yimagesize))

    allcenters = []
    allaccums = []
    allradii = []

    for dslice in data:
        # first temporal harmonic highlights the periodically moving heart
        ff1 = fftn(dslice['data'])
        fh = np.absolute(ifftn(ff1[1, :, :]))
        fh[fh < 0.1 * np.max(fh)] = 0.0
        image = 1. * fh / np.max(fh)

        # find Hough circles over the candidate radius range
        edges = canny(image, sigma=3)
        hough_radii = np.arange(minradius, maxradius, radstep)
        hough_res = hough_circle(edges, hough_radii)

        if hough_res.any():
            centers = []
            accums = []
            radii = []

            for radius, h in zip(hough_radii, hough_res):
                # For each radius, extract num_peaks circles
                peaks = peak_local_max(h, num_peaks=num_peaks)
                centers.extend(peaks)
                accums.extend(h[peaks[:, 0], peaks[:, 1]])
                radii.extend([radius] * len(peaks))  # len(peaks) can be < num_peaks

            # Keep the most prominent num_circles circles
            sorted_circles_idxs = np.argsort(accums)[::-1][:num_circles]

            for idx in sorted_circles_idxs:
                center_x, center_y = centers[idx]
                allcenters.append(centers[idx])
                allradii.append(radii[idx])
                allaccums.append(accums[idx])
                brightness = accums[idx]
                # accumulate a Gaussian likelihood surface for the ROI center
                lsurface = lsurface + brightness * np.exp(
                    -((xsurface - center_x) ** 2 + (ysurface - center_y) ** 2) / kernel_width ** 2)

    lsurface = lsurface / lsurface.max()

    # select the most likely ROI center
    roi_center = np.unravel_index(lsurface.argmax(), lsurface.shape)

    # determine the ROI radius
    roi_x_radius = 0
    roi_y_radius = 0
    for idx in range(len(allcenters)):
        xshift = np.abs(allcenters[idx][0] - roi_center[0])
        yshift = np.abs(allcenters[idx][1] - roi_center[1])
        if (xshift <= center_margin) & (yshift <= center_margin):
            roi_x_radius = np.max((roi_x_radius, allradii[idx] + xshift))
            roi_y_radius = np.max((roi_y_radius, allradii[idx] + yshift))

    if roi_x_radius > 0 and roi_y_radius > 0:
        roi_radii = roi_x_radius, roi_y_radius
    else:
        roi_radii = None

    return roi_center, roi_radii
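# Usage sketch (hypothetical study): slices is a list of {'data': (time, rows,
# cols) array} dicts, one per sax slice, and pixel_spacing is in mm/pixel.
#
#   roi_center, roi_radii = extract_roi(slices, pixel_spacing=1.4)
#   roi = {'roi_center': roi_center, 'roi_radii': roi_radii}
#   # roi_radii is None when no detected circle lies within center_margin
#   # pixels of the likelihood peak; feed roi into transform_norm_rescale.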