"""Resample LUNA16 CT scans and build nodule labels from the annotation CSV.

Run as a script, this walks ``subset[0-9]/*.mhd`` under a hard-coded source
root, resamples every scan to ``OUTPUT_SPACING``, converts the annotated
nodule positions into voxel coordinates (or a boolean mask) and pickles the
per-scan labels to ``exclude.pkl``.
"""
import csv
import glob
import os
import pickle
from time import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage
import SimpleITK as sitk
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from PIL import Image
from skimage import measure, morphology

OUTPUT_SPACING = [1.25, 1.25, 1.25]


def load_itk_image(filename):
    """Read an image file with SimpleITK.

    Returns:
        (image, origin, spacing): the voxel array plus the image origin and
        spacing as numpy arrays. Origin and spacing are reversed relative to
        the order SimpleITK reports them in, so they can be combined
        element-wise with the array's axes (treated as z, y, x elsewhere in
        this file).
    """
    itkimage = sitk.ReadImage(filename)
    numpy_image = sitk.GetArrayFromImage(itkimage)
    numpy_origin = np.array(list(reversed(itkimage.GetOrigin())))
    numpy_spacing = np.array(list(reversed(itkimage.GetSpacing())))
    return numpy_image, numpy_origin, numpy_spacing


def read_csv(filename):
    """Parse an annotation CSV into ``{series_uid: [nodule, ...]}``.

    The first row is treated as a header and skipped. Every other non-empty
    row must have exactly five columns: series uid, x, y, z, diameter.
    Each nodule is a dict ``{'position': [x, y, z], 'diameter': d}`` with
    float values.

    Raises:
        ValueError: if a row does not have five columns or a numeric field
            cannot be parsed as a float.
    """
    annotations_dict = {}
    with open(filename, 'r') as f:
        csvreader = csv.reader(f)
        next(csvreader, None)  # remove csv headers
        for row in csvreader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing one).
                continue
            series_uid, x, y, z, diameter = row
            value = {'position': [float(x), float(y), float(z)],
                     'diameter': float(diameter)}
            annotations_dict.setdefault(series_uid, []).append(value)
    return annotations_dict


def compute_coord(worldCoord, origin, spacing):
    """Convert a world coordinate to a (fractional) voxel coordinate.

    Computes ``|worldCoord - origin| / spacing`` element-wise; all three
    arguments must therefore use the same axis order.
    """
    stretched_voxel_coord = np.absolute(worldCoord - origin)
    return stretched_voxel_coord / spacing


def normalize_planes(npzarray):
    """Rescale intensities from the window [-1000, 400] to [0, 1], clipping.

    Returns a new array; the input is not modified.
    """
    # Alternative window that was tried previously:
    # maxHU = 600.
    # minHU = -1200.
    max_hu = 400.
    min_hu = -1000.
    scaled = (npzarray - min_hu) / (max_hu - min_hu)
    return np.clip(scaled, 0., 1.)


def zero_center(image):
    """Subtract the fixed mean 0.25 from ``image`` and return the result."""
    PIXEL_MEAN = 0.25
    return image - PIXEL_MEAN


def resample(image, org_spacing, new_spacing=OUTPUT_SPACING):
    """Resample ``image`` so its voxel spacing is close to ``new_spacing``.

    The output shape must be integral, so the requested spacing can only be
    met approximately; the spacing actually achieved is returned alongside
    the image and should be used for any later coordinate conversion.

    Args:
        image: array to resample.
        org_spacing: current spacing, one value per image axis.
        new_spacing: desired spacing, one value per image axis.

    Returns:
        (resampled_image, achieved_spacing)
    """
    # Coerce to arrays so plain lists/tuples work for either spacing.
    org_spacing = np.asarray(org_spacing, dtype=float)
    target_spacing = np.asarray(new_spacing, dtype=float)
    old_shape = np.asarray(image.shape)

    resize_factor = org_spacing / target_spacing
    new_shape = np.round(old_shape * resize_factor)
    # Zoom factor that lands exactly on the rounded shape.
    real_resize_factor = new_shape / old_shape
    achieved_spacing = org_spacing / real_resize_factor

    # scipy.ndimage.zoom is the supported entry point; the
    # scipy.ndimage.interpolation namespace is deprecated.
    image = scipy.ndimage.zoom(image, real_resize_factor, mode='nearest')
    return image, achieved_spacing


def _create_mask(arr_shape, position, diameter):
    """Return a boolean ball mask of shape ``arr_shape`` centred on ``position``.

    The radius is ``int(diameter // 2)`` voxels; ``position`` is (z, y, x).
    """
    z_dim, y_dim, x_dim = arr_shape
    z_pos, y_pos, x_pos = position
    z, y, x = np.ogrid[-z_pos:z_dim - z_pos,
                       -y_pos:y_dim - y_pos,
                       -x_pos:x_dim - x_pos]
    return z**2 + y**2 + x**2 <= int(diameter // 2)**2


def create_label(arr_shape, nodules, new_spacing, coord=False, origin=None):
    """Build the label for one resampled scan.

    Args:
        arr_shape: shape of the resampled image, (z, y, x).
        nodules: list of dict {'position', 'diameter'}; 'position' is
            indexed as [x, y, z] and reordered to (z, y, x) internally.
        new_spacing: spacing returned by ``resample`` for this scan.
        coord: if True return a list of ``[z, y, x, diameter]`` entries
            (integer voxel position, diameter in voxels); otherwise return
            a boolean mask of shape ``arr_shape`` with a ball per nodule.
        origin: scan origin in the same axis order as ``new_spacing``. When
            omitted, a module-level ``origin`` is used if one exists (this
            is what the function relied on before the parameter was added).

    Raises:
        NameError: if there are nodules to convert and no origin is
            available.
    """
    if coord:
        label = []
    else:
        label = np.zeros(arr_shape, dtype='bool')
    if not nodules:
        return label

    if origin is None:
        origin = globals().get('origin')
        if origin is None:
            raise NameError(
                "create_label needs the scan origin: pass origin=...")

    for nodule in nodules:
        position = nodule['position']
        world_coord = np.asarray([position[2], position[1], position[0]])
        # new_spacing came from resample
        voxel_coord = compute_coord(world_coord, origin, new_spacing)
        voxel_coord = [int(v) for v in voxel_coord]
        # The diameter is converted to voxels using the spacing of axis 1.
        diameter = nodule['diameter'] / new_spacing[1]
        if coord:
            label.append(voxel_coord + [diameter])
        else:
            label = np.logical_or(
                label, _create_mask(arr_shape, voxel_coord, diameter))
    return label


def plot(image, label, z_idx):
    """Show slice ``z_idx`` of ``image`` and of ``label`` side by side."""
    fig = plt.figure()
    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(image[z_idx, :, :], cmap='gray')
    ax = fig.add_subplot(1, 2, 2)
    ax.imshow(label[z_idx, :, :], cmap='gray')
    fig.show()


if __name__ == "__main__":
    # Not used below; retained so the script's imports are unchanged.
    from matplotlib.patches import Circle

    coord = True
    exclude_flag = True
    dst_spacing = OUTPUT_SPACING

    # src_root = '/lunit/data/LUNA16/rawdata'
    dst_root = '/data2/jhkim/npydata'
    src_root = '/data/jhkim/LUNA16/original'

    if not os.path.exists(dst_root):
        os.makedirs(dst_root)

    if exclude_flag:
        annotation_csv = os.path.join(
            src_root, 'CSVFILES/annotations_excluded.csv')
    else:
        annotation_csv = os.path.join(src_root, 'CSVFILES/annotations.csv')

    annotations = read_csv(annotation_csv)

    coord_dict = {}
    # Glob once; sorted so the processing order is reproducible.
    mhd_paths = sorted(
        glob.glob(os.path.join(src_root, 'subset[0-9]', '*.mhd')))
    all_file = len(mhd_paths)

    for cnt, mhd_path in enumerate(mhd_paths, 1):
        st = time()

        series_uid = os.path.splitext(os.path.basename(mhd_path))[0]
        subset_num = os.path.basename(os.path.dirname(mhd_path))
        dst_subset_path = os.path.join(dst_root, subset_num)
        if not os.path.exists(dst_subset_path):
            os.makedirs(dst_subset_path)

        np_img, origin, spacing = load_itk_image(mhd_path)

        resampled_img, new_spacing = resample(np_img, spacing, dst_spacing)
        resampled_img_shape = resampled_img.shape

        # Only needed by the (currently disabled) np.save calls below.
        norm_img = normalize_planes(resampled_img)
        norm_img = zero_center(norm_img)

        # Scans without annotations get an empty label; errors raised while
        # building a label are no longer swallowed.
        nodules = annotations.get(series_uid, [])
        label = create_label(resampled_img_shape, nodules, new_spacing,
                             coord=coord, origin=origin)

        # np.save(os.path.join(dst_subset_path,series_uid+'.npy'), norm_img)
        # np.save(os.path.join(dst_subset_path,series_uid+'.label.npy'), label)

        coord_dict[series_uid] = label
        # Rewritten after every scan, so the file on disk always holds the
        # labels of all scans processed so far.
        with open('exclude.pkl', 'wb') as f:
            pickle.dump(coord_dict, f, protocol=pickle.HIGHEST_PROTOCOL)

        print('{} / {} / {}'.format(cnt, all_file, time() - st))