#!/usr/bin/env python
import os
import shutil
import math
from glob import glob
import cv2
import csv
import random
from PIL import Image, ImageDraw
import dicom
import copy
import numpy as np
import SimpleITK as itk
from skimage import measure
import logging
import cPickle as pickle

# configuration options

DICOM_STRICT = False
SPACING = 0.8
GAP = 5
FAST = 400

if 'SPACING' in os.environ:
    SPACING = float(os.environ['SPACING'])
    print 'OVERRIDING SPACING = %f' % SPACING

if 'GAP' in os.environ:
    GAP = int(os.environ['GAP'])
    print 'OVERRIDING GAP = %d' % GAP

def get3c (images, i):
    if i < GAP:
        return None
    if i + GAP >= images.shape[0]:
        return None
    a = images[i-GAP]
    b = images[i]
    c = images[i+GAP]
    c3 = np.zeros(a.shape + (3,), dtype=np.float32)
    c3[:,:,0] = a
    c3[:,:,1] = b
    c3[:,:,2] = c
    return c3


#######################

def trim_loc (array1d, margin=0):
    w = np.where(array1d > 0)
    x0 = np.min(w)
    x1 = np.max(w)+1
    return max(0, x0-margin), min(x1+margin, array1d.shape[0])

def try_mkdir (path):
    try:
        os.makedirs(path)
    except:
        pass

def try_remove (path):
    try:
        os.remove(path)
    except:
        shutil.rmtree(path, ignore_errors=True)
    pass

def chunks (l, n):
    for i in range(0, len(l), n):
        yield l[i:(i+n)]

ROOT = os.path.abspath(os.path.dirname(__file__))
DATA_DIR = os.path.join(ROOT, 'data')

def dicom_error (dcm, msg, level=logging.ERROR):
    s = 'DICOM ERROR (%s): %s' % (dcm.filename, msg)
    if DICOM_STRICT and level >= logging.ERROR:
        raise Exception(s)
    else:
        logging.log(level, s)
    pass

# stores example uid & label information
# Stage.train = [(uid, label)]
# Stage.test = [(uid, 0.5)]
# Stage.examples = [(uid, 0.5)]

def dcm_sanity_check (dcm):
    rx, ry, rz, cx, cy, cz = [float(v) for v in dcm.ImageOrientationPatient]
    pass

class DICOM:
    def __init__ (self, dcm):
        self.patient_id = dcm.PatientID
        self.study_id = dcm.StudyInstanceUID
        self.series_id = dcm.SeriesInstanceUID
        self.HU = (float(dcm.RescaleSlope), float(dcm.RescaleIntercept))
        # filename as slice ID
        self.sid = os.path.splitext(os.path.basename(dcm.filename))[0]
        self.dcm = dcm
        self.image = dcm.pixel_array
        self.shape = dcm.pixel_array.shape
        self.pixel_padding = None
        try:
            self.pixel_padding = int(dcm.PixelPaddingValue)
        except:
            pass
        from dicom.tag import Tag
        #tag = Tag(0x0020,0x0032)
        #print dcm[tag].value
        #print dcm.ImagePositionPatient
        #assert dcm[tag] == dcm.ImagePositionPatient
        x, y, z = [float(v) for v in dcm.ImagePositionPatient]
        self.position = (x, y, z)
        rx, ry, rz, cx, cy, cz = [float(v) for v in dcm.ImageOrientationPatient]
        self.ori_row = (rx, ry, rz)
        self.ori_col = (cx, cy, cz)
        x, y = [float(v) for v in dcm.PixelSpacing]
        assert x == y
        self.spacing = x
        # Stage1: 4704 missing SliceLocation
        try:
            self.location = float(dcm.SliceLocation)
        except:
            dicom_error(dcm, 'Missing SliceLocation', level=logging.DEBUG)
            self.location = self.position[2]
            pass
        self.bits = dcm.BitsStored


        if False:
            # Non have SliceThickness
            tag = Tag(0x0018, 0x0050)
            if not tag in dcm:
                dicom_error(dcm, 'Missing SliceThickness', level=logging.WARN)
            else:
                logging.info('Has SliceThickness: %s' % dcm.filename)
                self.thickness = float(dcm[tag].value)

        # ???, why the value is as big as 63536
        if False:
            # Stage1 data:
            # 4704 have padding, 126057 not, so skip this
            self.padding = None
            try:
                self.padding = dcm.PixelPaddingValue
            except:
                dicom_error(dcm, 'Missing PixelPaddingValue', level=logging.WARN)
                pass

        # sanity check
        #if dcm.PatientName != dcm.PatientID:
        #    dicom_error(dcm, 'PatientName is not dcm.PatientID')
        if dcm.Modality != 'CT':
            dicom_error(dcm, 'Bad Modality: ' + dcm.Modality)
        #if Tag(0x0008,0x103e) in dcm:
        if False:
            if dcm.SeriesDescription != 'Axial' and dcm.SeriesDescription != 'mediastinal_lymph_nodes' and dcm.SeriesDescription != 'Recon 2: ACRIN LARGE' and dcm.SeriesDescription != 'Recon 3: CHEST-ABD':
                dicom_error(dcm, 'Bad SeriesDescription: ' + dcm.SeriesDescription)
        #if Tag(0x0008,0x0008) in dcm:
        #    if not 'AXIAL' in ' '.join(list(dcm.ImageType)).upper():
        #        dicom_error(dcm, 'Bad image type: ' + list(dcm.ImageType))

        ori_type_tag = Tag(0x0010,0x2210)
        if ori_type_tag in dcm:
            ori_type = dcm[ori_type_tag].value
            if 'BIPED' != ori_type:
                dicom_error(dcm, 'Bad Anatomical Orientation Type: ' + ori_type)
        # location should roughly be position.z
        self.funny_slice_location = abs(self.position[2] - self.location) > 10

        x, y, z = self.ori_row  # should be (1, 0, 0)
        if x < 0.9:
            dicom_error(dcm, 'Bad row orientation')
        x, y, z = self.ori_col  # should be (0, 1, 0)
        if y < 0.9:
            dicom_error(dcm, 'Bad col orientation')
        pass
    pass

def segment_lung_axial (image, th=123.85, dilate=0.01):

    blur = np.copy(image)
    for i in range(blur.shape[0]):
        cv2.blur(blur[i], (5,5), blur[i])

    binary = np.array(blur < th, dtype=np.uint8)
    # 0: body
    # 1: air & lung
    labels = measure.label(binary, background=-1)

    # set air (same cc as corners) -> body
    bg_labels = set()
    for z in [0, -1]:
        for y in [0, -1]:
            for x in [0, -1]:
                bg_labels.add(labels[z, y, x])
    bg_labels = list(bg_labels)
    print(bg_labels)
    if len(bg_labels) > 1:
        logging.warn('bg not connected, detected %d components' % len(bg_labels))
        pass
    for bg_label in bg_labels:
        binary[bg_label == labels] = 0
        pass


    # now binary:
    #   0: non-lung & body tissue in lung
    #   1: lung & holes in body
    for i, sl in enumerate(binary):
        #H, W = sl.shape
        ll = measure.label(sl, background=-1)   # connected components
        # biggest CC should be body
        vv, cc = np.unique(ll, return_counts=True) 
        assert len(vv) > 0
        body_ll = vv[np.argmax(cc)]
        binary[i][ll != body_ll] = 1
        pass

    # set corner again
    labels = measure.label(binary, background=0)
    bg_labels = set([0])
    for z in [0, -1]:
        for y in [0, -1]:
            for x in [0, -1]:
                bg_labels.add(labels[z, y, x])

    val_counts = zip(*np.unique(labels, return_counts=True))
    val_counts = [x for x in val_counts if not x[0] in bg_labels]   # remove background
    val_counts = sorted(val_counts, key=lambda x:-x[1]) # sort by size
    th = val_counts[0][1] /4    # 1/4 size of the larged connected component (must be lung)
    val = [v for v, c in val_counts if c >= th]
    if len(val) >= 3:
        logging.warn('more than 2 lungs parts detected %d' % len(val))
    binary = np.zeros_like(binary, dtype=np.uint8)
    for v in val:
        binary[labels == v] = 1

    H, W = binary[0].shape
    dilate = int(round(math.sqrt(1.0 * H * W) * dilate))
    #print("DILATE: ", dilate)
    kernel = np.ones((dilate, dilate), dtype=np.int32)

    for i in range(binary.shape[0]):
        cv2.dilate(binary[i], kernel, binary[i])

    #image[binary == 0] = 255
    #image = 255 - image
    #image[binary == 0] = 255
    return binary
    #return image #* binary.astype(image.dtype)


AXIAL = 0
SAGITTAL = 1
CORONAL = 2

VIEWS = [AXIAL, SAGITTAL, CORONAL]
VIEW_NAMES = ['axial', 'sagittal', 'coronal']

AXES_ORDERS = ([0, 1, 2],  # AXIAL
               [2, 1, 0],  # SAGITTAL
               [1, 0, 2])  # CORONAL


def index_view (I, view):
    assert len(I) == 3
    a, b, c = AXES_ORDERS[view]
    return [I[a], I[b], I[c]]

def strip_pad_512 (n, size=512):
    if n >= size:
        from_x = (n-size)/2
        to_x = 0
        n_x = size
        shift_x = from_x
    else:
        from_x = 0
        to_x = (size-n)/2
        n_x = n
        shift_x = -to_x
    return from_x, to_x, n_x, shift_x

class CaseBase (object):
    # self.images
    # self.spacing
    # self.origin   # !!! origin is never transposed!!!
    # self.axes
    # self.vspacing
    def __init__ (self):
        self.uid = None
        self.path = None
        self.images = None      # 3-D array
        self.spacing = None     #
        self.origin = None      # origin never changes
                                # under transposing
        self.view = None
        self.anno = None
        # We save the coefficients for normalize to
        # Hounsfield Units, and keep that updated
        # when normalizing
        self.HU = None          # (intercept, slope)
        self.dcm_z_position = None

        self.orig_origin = None
        self.orig_spacing = None
        self.orig_shape = None
        self.pixel_padding = None
        pass

    def copy_replace_images (self, images):
        case = CaseBase()
        case.uid = self.uid
        case.orig_origin = self.orig_origin
        case.orig_spacing = self.orig_spacing
        case.orig_shape = self.orig_shape
        case.path = self.path
        case.images = images
        case.spacing = self.spacing
        case.view = self.view
        case.origin = self.origin
        case.anno = self.anno
        return case

    def normalizeHU (self):
        assert not self.HU is None
        a, b = self.HU
        self.images *= a
        self.images += b
        self.HU = (1.0, 0)
        if not self.pixel_padding is None:
            self.pixel_padding = self.pixel_padding * a + b
        pass

    def transpose_array (self, view, array):
        if self.view == view:
            return array
        elif self.view == AXIAL and view == SAGITTAL:
            d1, d2 = 0, 2
        elif self.view == AXIAL and view == CORONAL:
            d1, d2 = 0, 1
        elif self.view == SAGITTAL and view == AXIAL:
            d1, d2 = 0, 2
        elif self.view == CORONAL and view == AXIAL:
            d1, d2 = 0, 1
        else:
            assert False
        return np.swapaxes(array, d1, d2)

    def transpose (self, view):
        if self.view == view:
            return self
        elif self.view == AXIAL and view == SAGITTAL:
            d1, d2 = 0, 2
        elif self.view == AXIAL and view == CORONAL:
            d1, d2 = 0, 1
        elif self.view == SAGITTAL and view == AXIAL:
            d1, d2 = 0, 2
        elif self.view == CORONAL and view == AXIAL:
            d1, d2 = 0, 1
        else:
            assert False

        case = CaseBase()
        case.uid = self.uid
        case.orig_origin = self.orig_origin
        case.orig_spacing = self.orig_spacing
        case.orig_shape = self.orig_shape
        case.path = self.path
        case.images = np.swapaxes(self.images, d1, d2)
        assert isinstance(self.spacing, tuple)
        sp = list(self.spacing)
        sp[d1], sp[d2] = sp[d2], sp[d1]
        case.spacing = tuple(sp)
        case.view = view
        case.origin = self.origin
        case.anno = self.anno
        return case

    def round512 (self, size=512):
        target = np.zeros((size,size,size), dtype=self.images.dtype)
        Z, Y, X = self.images.shape
        from_z, to_z, n_z, shift_z = strip_pad_512(Z, size=size)
        from_y, to_y, n_y, shift_y = strip_pad_512(Y, size=size)
        from_x, to_x, n_x, shift_x = strip_pad_512(X, size=size)
        target[to_z:(to_z+n_z),
               to_y:(to_y+n_y),
               to_x:(to_x+n_x)] = self.images[from_z:(from_z+n_z),
                   from_y:(from_y+n_y),
                   from_x:(from_x+n_x)]
        self.origin[0] += shift_z * self.spacing[0]
        self.origin[1] += shift_y * self.spacing[1]
        self.origin[2] += shift_x * self.spacing[2]
        print("off", to_x, to_y, to_z)
        print("len", n_x, n_y, n_z)
        print("shi", shift_x, shift_y, shift_z)
        self.images = target
        pass

    def strip (self, mask, margin1=2, margin2=10):
        z0, z1 = trim_loc(np.sum(mask, axis=(1,2)), margin=margin1)
        y0, y1 = trim_loc(np.sum(mask, axis=(0,2)), margin=margin2)
        x0, x1 = trim_loc(np.sum(mask, axis=(0,1)), margin=margin2)
        self.origin[0] += z0 * self.spacing[0]
        self.origin[1] += y0 * self.spacing[1]
        self.origin[2] += x0 * self.spacing[2]
        self.images = self.images[z0:z1, y0:y1, x0:x1]
        pass

    def round_stride (self, stride=16):
        T, H, W = self.images.shape[:3]
        nT = T / stride * stride
        nH = H / stride * stride
        nW = W / stride * stride
        oT = (T - nT)/2
        oH = (H - nH)/2
        oW = (W - nW)/2
        self.origin[0] += oT * self.spacing[0]
        self.origin[1] += oH * self.spacing[1]
        self.origin[2] += oW * self.spacing[2]
        self.images = self.images[oT:(oT+nT),oH:(oH+nH),oW:(oW+nW)]
        return oT, oH, oW
        pass

    # consider using scipy.ndimage.interpolation
    def rescale (self, slices = None, spacing = None, size = None, method=2):
        # if slices != self.images.shape[0], use method:
        #   0:  adjust slices, so everything is integer and no rounding or approx. is done
        #   1:  do not change slices, use nearest neighbor
        #   2:  do not change slices, use interpolation
        N, H, W = self.images.shape
        case = CaseBase()
        case.uid = self.uid
        case.orig_origin = self.orig_origin
        case.orig_spacing = self.orig_spacing
        case.orig_shape = self.orig_shape
        case.path = self.path
        case.view = self.view
        case.origin = self.origin
        case.anno = self.anno
        case.HU = self.HU
        assert (spacing and not size) or (size and not spacing)
        if (not slices) or (slices == N):
            method = 0
            slices = N
            step = 1
            off = 0
            sp1 = self.spacing[0]
        elif method == 0:   # TODO: need to do actual samping
                # origin under this is not correct due to non-0 off
            step = int(round(N / slices))
            slices = N / step
            off = (N - slices * step) / 2
            sp1 = self.spacing[0] * step
        else:
            off = 0
            step = float(N -1)/ (slices - 1)
            sp1 = self.spacing[0] * step
            pass
        if spacing:
            H = int(round((H-1) * self.spacing[1] / spacing + 1))
            W = int(round((W-1) * self.spacing[2] / spacing + 1))
            resize = (W, H)
            sp2 = spacing
            sp3 = spacing
        elif size:
            sp2 = self.spacing[1] * (H-1) / (size-1)
            sp3 = self.spacing[2] * (W-1) / (size-1)
            resize = (size, size)
            H = size
            W = size
        else:
            resize = None
            _, sp2, sp3 = self.spacing

        case.spacing = (sp1, sp2, sp3)
        case.images = np.zeros((slices, H, W), dtype=np.float32)

        for i in range(slices):
            if method == 0 or method == 1:
                arr = int(round(off))
                image = self.images[arr, :, :]
            elif method == 2:
                L = int(math.floor(off))
                R = int(math.ceil(off))
                if R <= 0:
                    image = self.images[0, :, :]
                elif L >= N-1:
                    image = self.images[N-1, :, :]
                elif R - L < 0.5:   # R == L
                    image = self.images[L, :, :]
                else:
                    image = (self.images[L, :, :]  * (R - off) + self.images[R, :, :] * (off - L)) / (R - L)
                pass

            if resize:
                cv2.resize(image, resize, case.images[i, :, :])
            else:
                case.images[i, :, :] = image
            off += step

        return case

    def rescale3D (self, spacing):
        slices = int(round(self.spacing[0] * (self.images.shape[0] - 1) / spacing + 1))
        return self.rescale(slices, spacing, size=None, method=2)
    pass

    def normalize (self, min=0, max=1, min_th = -1000, max_th = 400):
        assert self.images.dtype == np.float32
        if not min_th is None:
            self.images[self.images < min_th] = min_th
        if not max_th is None:
            self.images[self.images > max_th] = max_th
        m = min_th #np.min(self.images)
        M = max_th #np.max(self.images)
        scale = (1.0 * max - min)/(M - m)
        logging.debug('norm %f %f' % (m, M))
        self.images -= m
        self.images *= scale
        self.images += min
        # recalculate HU 
        #   I:  original image
        #   I': new image
        #   a'I' + b' = aI + b
        #   I' = (I-m) * scale + min
        #      = I*scale + (min - m * scale)
        #   so
        #   a'I*scale + (min - m * scale)*a' + b' = aI + b
        #   
        #   a' = a / scale
        #   b' = b + a'(m * scale -min)
        #      = b + a * (m - min/scale)
        if self.HU:
            a, b = self.HU
            #self.HU = (a * (M -m), b + a * m)
            self.HU = (a / scale, b + a * (m - min/scale))
        pass

    def standardize_color (self):
        self.normalizeHU()
        self.normalize(min_th=-1000,max_th=400,min=0,max=255)
        pass

    def standardize_color16 (self):
        self.normalizeHU()
        self.normalize(min_th=-1000,max_th=400,min=0,max=1400)
        pass

    # return center coordinate
    def world_to_vox (self, world):
        # change view
        # change origin
        z, y, x, r = world
        z0, y0, x0 = self.origin
        cc = (np.array(world[:3])-np.array(self.origin))
        cc = cc[AXES_ORDERS[self.view]]
        spacing = np.array(self.spacing)
        cc = cc / spacing
        rr = r / spacing
        #print "xxx", cc[0], rr[0]
        return  cc, rr

    def picpac_anno (self):
        # !!! annotation is center & radius instead of orign + size in picpac
        if self.anno is None:
            return []
        ALL = []
        nodules = [ self.world_to_vox(world) for world in self.anno]
        C, H, W = self.images.shape
        for (z, y, x), (zr, yr, xr) in nodules:
            first = max(0, int(math.ceil(z - zr)))
            last = min(C-1, int(math.floor(z + zr)))
            if first > last:
                continue
            nod = []
            x /= W
            y /= H 
            for i in range(first, last + 1):
                cos = abs(i - z) / zr
                sin = math.sqrt(1 - cos * cos)
                cyr = yr * sin / H
                cxr = xr * sin / W
                nod.append([i, x, y, cxr, cyr])
                #print 'ellipse', x, y, xr, yr
                #pass
                pass
            ALL.append(nod)
            pass
        return ALL

    def papaya_box (self, box):
        out = [0]*6
        assert self.view == AXIAL
        for i in range(3):
            out[i] = int(round((self.origin[i] + self.spacing[i] * box[i] - self.orig_origin[i]) / self.orig_spacing[i]))
            out[i+3] = int(round((self.origin[i] + self.spacing[i] * box[i+3] - self.orig_origin[i]) / self.orig_spacing[i]))
            pass
        D, _, W = self.orig_shape
        out[0], out[3] = D-out[3], D-out[0]
        out[2], out[5] = W-out[5], W-out[2]
        return out

    def save_gif (self, path, anno=False, aug=2, step=1):
        # must normalize first to [0, 1]
        cube = np.uint8(np.clip(self.images, 0, 255))
        frames = [Image.fromarray(cube[i,:,:]) for i in range(0, cube.shape[0], step)]
        if anno:
            C, H, W = self.images.shape
            annos = self.picpac_anno()
            for nodule in annos:
                for j, x, y, rx, ry in nodule:
                    x *= W
                    y *= H
                    rx *= W * aug
                    ry *= H * aug
                    draw = ImageDraw.Draw(frames[j])
                    draw.ellipse([math.floor(x-rx),
                                  math.floor(y-ry),
                                  math.ceil(x+rx),
                                  math.ceil(y+ry)], outline=255)
                    del draw
            pass
        frames[0].save(path, save_all=True, append_images=frames[1:], duration=0.1, loop=0)
        pass

def group_zrange (dcms):
    zs = [float(dcm.dcm.ImagePositionPatient[2]) for dcm in dcms]
    zs = sorted(zs)
    gap = 1000000
    if len(zs) > 1:
        gap = zs[1] - zs[0]
    return (zs[0], zs[-1], gap)

def regroup_dcms (dcms):
    acq_groups = {}
    for dcm in dcms:
        an = 0
        try:
            an = int(dcm.dcm.AcquisitionNumber)
        except:
            pass
        acq_groups.setdefault(an, []).append(dcm)	
        pass
    groups = acq_groups.values()
    if len(groups) == 1:
        return groups[0]
    # we have multiple acquisitions
    zrs = [group_zrange(group) for group in groups]
    zrs = sorted(zrs, key=lambda x: x[0])
    min_gap = min([zr[2] for zr in zrs])
    gap_th = 2.0 * min_gap
    prev = zrs[0]
    bad = False
    for zr in zrs[1:]:
        gap = zr[0] - prev[1]
        if gap < 0 or gap > gap_th:
            bad = True
            break
        if gap != min_gap:
            logging.error('bad gap')
        prev = zr
    if not bad:
        logging.error('multiple acquisitions merged')
        return dcms
    # return the maximal groups
    gs = max([len(group) for group in groups])
    acq_groups = {k:v for k, v in acq_groups.iteritems() if len(v) == gs}
    key = max(acq_groups.keys())
    group = acq_groups[key]
    print(acq_groups.keys(), key)
    logging.error('found conflicting groups. keeping max acq number, %d out of %d dcms' % (len(group), len(dcms)))
    return group

# All DiCOMs of a UID, organized
class FsCase (CaseBase):
    def __init__ (self, path, regroup = True):
        CaseBase.__init__(self)
        self.path = path
        #self.thumb_path = os.path.join(DATA_DIR, 'thumb', uid)
        # load path
        dcms = []
        for dcm_path in glob(os.path.join(self.path, '*.dcm')):
            dcm = dicom.read_file(dcm_path)
            try:
                boxed = DICOM(dcm)
            except:
                print dcm.filename
                raise
            dcms.append(boxed)
            assert dcms[0].spacing == boxed.spacing
            assert dcms[0].shape == boxed.shape
            assert dcms[0].ori_row == boxed.ori_row
            assert dcms[0].ori_col == boxed.ori_col
            if dcms[0].pixel_padding != boxed.pixel_padding:
                logging.warn('0 padding %s, but now %s, %s' %
                        (dcms[0].pixel_padding, boxed.pixel_padding, dcm.filename))
            #assert dcms[0].HU == boxed.HU
            #print boxed.HU
            pass
        assert len(dcms) >= 2

        if regroup:
            dcms = regroup_dcms(dcms)
        self.pixel_padding = dcms[0].pixel_padding

        dcms.sort(key=lambda x: x.position[2])
        zs = []
        for i in range(1, len(dcms)):
            zs.append(dcms[i].position[2] - dcms[i-1].position[2])
            pass
        zs = np.array(zs)
        z_spacing = np.mean(zs)
        assert z_spacing > 0
        assert np.max(np.abs(zs - z_spacing)) * 1000 < z_spacing

        #self.length = dcms[-1].position[2] - dcms[0].position[2]
        front = dcms[0]
        #self.sizes = (front.shape[0] * front.spacing, front.shape[1] * front.spacing, self.length)
        self.dcms = dcms

        images = np.zeros((len(dcms),)+front.image.shape, dtype=np.float32)
        HU = front.HU
        for i in range(len(dcms)):
            HU2 = dcms[i].HU
            images[i,:,:] = dcms[i].image
            if HU2 != HU:
                logging.warn("HU: (%d) %s => %s, %s" % (i, HU2, HU, dcms[i].dcm.filename))
                images[i, :, :] *= HU2[0] / HU[0]
                images[i, :, :] += (HU2[1] - HU[1])/HU[0]
        self.dcm_z_position = {}
        for dcm in dcms:
            name = os.path.splitext(os.path.basename(dcm.dcm.filename))[0]
            self.dcm_z_position[name] = dcm.position[2] - front.position[2]
            pass
        # spacing   # z, y, x
        self.images = images
        self.spacing = (z_spacing, front.spacing, front.spacing)
        x, y, z = front.position
        self.origin = [z, y, z] #front.location
        self.view = AXIAL
        self.anno = None
        self.HU = HU
        self.orig_origin = copy.deepcopy(self.origin)
        self.orig_spacing = copy.deepcopy(self.spacing)
        self.orig_shape = copy.deepcopy(self.images.shape)
        # sanity check
        pass
    pass

class Case:
    def __init__ (self, uid, regroup = True):
        self.uid = uid
        self.path = os.path.join(DATA_DIR, 'bowl', uid)
        if not os.path.exists(self.path):
            self.path = os.path.join(DATA_DIR, 'samples', uid)
            if not os.path.exists(self.path):
                cc = glob(os.path.join(DATA_DIR, 'lymph', 'data', uid, '*/*'))
                if len(cc) >= 1:
                    self.path = cc[0]
                    assert os.path.exists(self.path)
                    if len(cc) > 1:
                        logging.warn('multiple candidates for ' + uid)
                else:
                    cc = glob(os.path.join(DATA_DIR, 'lymph', 'data', '*/*', uid))
                    if len(cc) >= 1:
                        self.path = cc[0]
                        assert os.path.exists(self.path)
                        if len(cc) > 1:
                            logging.warn('multiple candidates for ' + uid)
                    else:
                        raise Exception('data not found for uid %s' % uid)
            pass
        FsCase.__init__(self, self.path, regroup)
        pass
    pass

LUNA_DIR = os.path.join(ROOT, 'data', 'luna')
#LUNA_DIR = os.path.join('data', 'luna')

def load_luna_dir_layout ():
    lookup = {}
    for i in range(10):
        sub = os.path.join(LUNA_DIR, 'subset%d' % i)
        for f in glob(os.path.join(sub, '*.mhd')):
            bn = os.path.splitext(os.path.basename(f))[0]
            #print bn, "=>", sub
            lookup[bn] = sub
            pass
        pass
    return lookup


def load_luna_csv (filename):
    lines = []
    with open(filename, "rb") as f:
        csvreader = csv.reader(f)
        for line in csvreader:
            lines.append(line)
            return lines
        pass
    pass

def load_luna_annotations ():
    ALL = {}
    with open(os.path.join(LUNA_DIR, 'CSVFILES', 'annotations.csv'), 'r') as f:
        f.next()
        for l in f:
            #print l
            uid, x, y, z, d = l.strip().split(',')
            x = float(x)
            y = float(y)
            z = float(z)
            r = float(d)/2
            ALL.setdefault(uid, []).append((z, y, x, r))
            pass
        pass
    return ALL

def load_luna_meta ():
    cache_path = os.path.join(LUNA_DIR, 'meta.pkl')
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    logging.warn('loading luna meta data')
    meta = (load_luna_dir_layout(),
            load_luna_csv(os.path.join(LUNA_DIR, 'CSVFILES', 'candidates.csv')),
            load_luna_annotations())
    with open(cache_path, 'wb') as f:
        pickle.dump(meta, f)
    return meta

#LUNA_DIR_LOOKUP, _, LUNA_ANNO = load_luna_meta()
LUNA_DIR_LOOKUP = {}
LUNA_ANNO = {}

def worldToVoxelCoord(worldCoord, origin, spacing):
    stretchedVoxelCoord = np.absolute(worldCoord - origin)
    voxelCoord = stretchedVoxelCoord / spacing
    return voxelCoord



# All DiCOMs of a UID, organized
class LunaCase (CaseBase):
    def __init__ (self, uid):
        CaseBase.__init__(self)
        self.uid = uid
        self.path = os.path.join(LUNA_DIR_LOOKUP[uid], uid + '.mhd')
        if not os.path.exists(self.path):
            raise Exception('data not found for uid %s at %s' % (uid, self.path))
        pass
        #self.thumb_path = os.path.join(DATA_DIR, 'thumb', uid)
        # load path
        itkimage = itk.ReadImage(self.path)
        self.HU = (1.0, 0.0)
        self.images = itk.GetArrayFromImage(itkimage).astype(np.float32)
        #print type(self.images), self.images.dtype
        self.origin = list(reversed(itkimage.GetOrigin()))
        self.spacing = list(reversed(itkimage.GetSpacing()))
        self.view = AXIAL
        _, a, b = self.spacing
        self.anno = LUNA_ANNO.get(uid, None)
        assert a == b
        # sanity check
        pass
    pass

def save_mask (path, mask):
    shape = np.array(list(mask.shape), dtype=np.uint32)
    total = mask.size
    totalx = (total +7 )/ 8 * 8
    if totalx == total:
        padded = mask
    else:
        padded = np.zeros((totalx,), dtype=np.uint8)
        padded[:total] = np.reshape(mask, (total,))
        pass
    padded = np.reshape(padded, (totalx/8, 8))
    print padded.shape
    packed = np.packbits(padded)
    print packed.shape
    np.savez_compressed(path, shape, packed)
    pass

def load_mask (path):
    import sys
    saved = np.load(path)
    shape = saved['arr_0']
    D, H, W = shape
    size = D * H * W
    packed = saved['arr_1']
    padded = np.unpackbits(packed)
    binary = padded[:size]
    return np.reshape(binary, [D, H, W])

def is_kaggle (uid):
    return len(uid) == 32
    
def load_case (uid):
    if is_kaggle(uid):
        return Case(uid)
    else:
        return LunaCase(uid)
    pass


def load_8bit_lungs (uid):
    #path = os.path.join('data/cache', uid)
    #if os.path.exists(path):
    #    with open(path, 'rb') as f:
    #        return pickle.load(f)
    case = load_case(uid)

    case.standardize_color()
    cache = os.path.join('maskcache/mask-123.85-0.01/%s.npz' % case.uid)
    binary = None
    if os.path.exists(cache) and os.path.getsize(cache) > 0:
        # load cache
        binary = load_mask(cache)
        assert not binary is None
    if binary is None:
        binary = segment_lung_axial(case.images) #, th=200.85)
        save_mask(cache, binary)
        pass
    case.images[binary==0] = 255
    case.images *= -1
    case.images += 255
    #case = case.rescale3D(1.0)
    #with open(path, 'wb') as f:
    #    pickle.dump(case, f)
    #return case
    return case

def load_8bit_lungs_noseg (uid):
    #path = os.path.join('data/cache', uid)
    #if os.path.exists(path):
    #    with open(path, 'rb') as f:
    #        return pickle.load(f)
    case = load_case(uid)
    case.standardize_color()
    #case.images = segment_lung_axial(case.images) #, th=200.85)
    #case.images *= -1
    #case.images += 255
    #case = case.rescale3D(1.0)
    #with open(path, 'wb') as f:
    #    pickle.dump(case, f)
    #return case
    return case

def load_16bit_lungs_noseg (uid):
    case = load_case(uid)
    case.standardize_color16()
    return case

def segment_lung_axial_v2 (image, th):

    blur = np.copy(image)
    for i in range(blur.shape[0]):
        cv2.blur(blur[i], (5,5), blur[i])

    binary = np.array(blur < th, dtype=np.uint8)
    # 0: body
    # 1: air & lung
    labels = measure.label(binary, background=-1)

    # set air (same cc as corners) -> body
    bg_labels = set()
    for z in [0, -1]:
        for y in [0, -1]:
            for x in [0, -1]:
                bg_labels.add(labels[z, y, x])
    bg_labels = list(bg_labels)
    print(bg_labels)
    if len(bg_labels) > 1:
        logging.warn('bg not connected, detected %d components' % len(bg_labels))
        pass
    for bg_label in bg_labels:
        binary[bg_label == labels] = 0
        pass


    # now binary:
    #   0: non-lung & body tissue in lung
    #   1: lung & holes in body
    for i, sl in enumerate(binary):
        #H, W = sl.shape
        ll = measure.label(sl, background=-1)   # connected components
        # biggest CC should be body
        vv, cc = np.unique(ll, return_counts=True) 
        assert len(vv) > 0
        body_ll = vv[np.argmax(cc)]
        binary[i][ll != body_ll] = 1
        pass

    # set corner again
    labels = measure.label(binary, background=0)
    bg_labels = set([0])
    for z in [0, -1]:
        for y in [0, -1]:
            for x in [0, -1]:
                bg_labels.add(labels[z, y, x])

    val_counts = zip(*np.unique(labels, return_counts=True))
    val_counts = [x for x in val_counts if not x[0] in bg_labels]   # remove background
    val_counts = sorted(val_counts, key=lambda x:-x[1]) # sort by size
    th = val_counts[0][1] /4    # 1/4 size of the larged connected component (must be lung)
    val = [v for v, c in val_counts if c >= th]
    if len(val) >= 3:
        logging.warn('more than 2 lungs parts detected %d' % len(val))
    binary = np.zeros_like(binary, dtype=np.uint8)
    for v in val:
        binary[labels == v] = 1

    H, W = binary[0].shape
    dilate = int(round(math.sqrt(1.0 * H * W) * dilate))
    #print("DILATE: ", dilate)
    kernel = np.ones((dilate, dilate), dtype=np.int32)

    for i in range(binary.shape[0]):
        cv2.dilate(binary[i], kernel, binary[i])

    #image[binary == 0] = 255
    #image = 255 - image
    #image[binary == 0] = 255
    return binary

def load_lungs_mask (uid):
    cache = os.path.join('maskcache/mask-v2/%s.npz' % case.uid)
    binary = None
    if os.path.exists(cache) and os.path.getsize(cache) > 0:
        # load cache
        binary = load_mask(cache)
        assert not binary is None
    if binary is None:
        case = load_case(uid)
        case.normalizeHU()
        binary = segment_lung_axial_v2(case.images) #, th=200.85)
        save_mask(cache, binary)
        pass
    return binary


#def load_lung_mask (uid):
#    #path = os.path.join('data/cache', uid)
#    #if os.path.exists(path):
#    #    with open(path, 'rb') as f:
#    #        return pickle.load(f)
#    case = load_case(uid)
#
#    cache = os.path.join('cache/mask-123.85-0.01/%s.npz' % case.uid)
#    binary = None
#    if os.path.exists(cache) and os.path.getsize(cache) > 0:
#        # load cache
#        binary = load_mask(cache)
#        assert not binary is None
#    if binary is None:
#        case.standardize_color()
#        binary = segment_lung_axial(case.images) #, th=200.85)
#        save_mask(cache, binary)
#        pass
#    case.images = binary.astype(np.float32)
#    return case

def load_fts (path):
    with open(path, 'rb') as f:
        return pickle.load(f)
    pass


def patch_clip_range (x, tx, wx, X):
    if x < 0:   #
        wx += x
        tx -= x
        x = 0
    if x + wx > X:
        d = x + wx - X
        wx -= d
        pass
    return x, tx, wx

def extract_patch_3c (images, z, y, x, size):
    assert len(images.shape) == 3
    _, Y, X = images.shape
    z = int(round(z))
    y = int(round(y))
    x = int(round(x))
    image = get3c(images, z)
    if image is None:
        return None
    ty = 0
    tx = 0
    y -= size/2
    x -= size/2
    wy = size
    wx = size
    print y, ty, wy, x, tx, wx
    y, ty, wy = patch_clip_range(y, ty, wy, Y)
    x, tx, wx = patch_clip_range(x, tx, wx, X)
    # now do overlap
    patch = np.zeros((size, size, 3), dtype=image.dtype)
    print y, ty, wy, x, tx, wx
    patch[ty:(ty+wy),tx:(tx+wx),:] = image[y:(y+wy),x:(x+wx),:]
    return patch