import os
import shutil
import numpy as np
import sys
# sys.path.append('./')
# import config_training
from config_training import config 
# config = config_training.config
from scipy.io import loadmat
import numpy as np
import h5py
import pandas
import scipy
from scipy.ndimage.interpolation import zoom
from skimage import measure
import SimpleITK as sitk
from scipy.ndimage.morphology import binary_dilation,generate_binary_structure
from skimage.morphology import convex_hull_image
import pandas
from multiprocessing import Pool
from functools import partial

import warnings

def resample(imgs, spacing, new_spacing, order=2):
    """Resample a 3-D volume, or a stack of 3-D volumes, to a new voxel spacing.

    Args:
        imgs: 3-D array, or 4-D array whose last axis indexes independent
            3-D volumes that are resampled one at a time.
        spacing: current voxel spacing, one value per spatial axis.
        new_spacing: requested voxel spacing, one value per spatial axis.
        order: spline interpolation order handed to ``zoom``.

    Returns:
        ``(resampled, true_spacing)``; ``true_spacing`` is the spacing that is
        actually obtained once the target shape has been rounded to whole
        voxels.

    Raises:
        ValueError: if ``imgs`` is neither 3-D nor 4-D.
    """
    if imgs.ndim == 3:
        old_shape = np.array(imgs.shape)
        new_shape = np.round(old_shape * spacing / new_spacing)
        true_spacing = spacing * old_shape / new_shape
        resize_factor = new_shape / old_shape
        imgs = zoom(imgs, resize_factor, mode='nearest', order=order)
        return imgs, true_spacing

    if imgs.ndim == 4:
        channels = []
        for i in range(imgs.shape[-1]):
            # Forward ``order``: without it every channel silently fell back
            # to the default order instead of the one the caller requested.
            channel, true_spacing = resample(imgs[:, :, :, i], spacing, new_spacing, order=order)
            channels.append(channel)
        # Move the channel axis back to the last position.
        return np.transpose(np.array(channels), [1, 2, 3, 0]), true_spacing

    raise ValueError('wrong shape')
def worldToVoxelCoord(worldCoord, origin, spacing):
    """Convert a world-space coordinate into fractional voxel indices.

    The offset from ``origin`` is taken in absolute value and then divided
    element-wise by ``spacing``.
    """
    return np.absolute(worldCoord - origin) / spacing

def load_itk_image(filename):
    """Read an image with SimpleITK together with its geometry.

    The file is first scanned as text for its ``TransformMatrix`` line; the
    image counts as flipped when the rounded matrix differs from the 3x3
    identity.

    Returns:
        ``(numpyImage, numpyOrigin, numpySpacing, isflip)`` where origin and
        spacing are the SimpleITK tuples in reversed order, matching the axis
        order of the returned array.
    """
    identity = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])
    with open(filename) as header:
        header_lines = header.readlines()
    matrix_line = [ln for ln in header_lines if ln.startswith('TransformMatrix')][0]
    matrix_values = np.array(matrix_line.split(' = ')[1].split(' ')).astype('float')
    isflip = bool(np.any(np.round(matrix_values) != identity))

    itkimage = sitk.ReadImage(filename)
    numpyImage = sitk.GetArrayFromImage(itkimage)
    numpyOrigin = np.array(itkimage.GetOrigin()[::-1])
    numpySpacing = np.array(itkimage.GetSpacing()[::-1])

    return numpyImage, numpyOrigin, numpySpacing, isflip

def process_mask(mask):
    """Convexify a 3-D mask slice by slice, then dilate it.

    Each non-empty slice is replaced by its convex hull unless the hull has
    more than 1.5 times as many pixels as the slice itself, in which case the
    slice is kept unchanged. The stack is then dilated for 10 iterations with
    the connectivity-1 3-D structuring element.
    """
    hull_stack = np.copy(mask)
    for idx, plane in enumerate(mask):
        plane = np.ascontiguousarray(plane)
        chosen = plane
        if np.sum(plane) > 0:
            hull = convex_hull_image(plane)
            # Accept the hull only when it does not inflate the slice too much.
            if np.sum(hull) <= 1.5 * np.sum(plane):
                chosen = hull
        hull_stack[idx] = chosen

    return binary_dilation(hull_stack,
                           structure=generate_binary_structure(3, 1),
                           iterations=10)


def lumTrans(img):
    """Map intensities from the window [-1200, 600] onto uint8 values 0..255.

    Values below the window become 0 and values above it become 255.
    """
    lungwin = np.array([-1200., 600.])
    low, high = lungwin
    scaled = (img - low) / (high - low)
    np.clip(scaled, 0, 1, out=scaled)
    return (scaled * 255).astype('uint8')

def binarize_per_slice(image, spacing, intensity_th=-600, sigma=1, area_th=30, eccen_th=0.99, bg_patch_size=10):
    """Threshold every slice of a volume and keep only plausible 2-D regions.

    Each slice is Gaussian-smoothed and compared against ``intensity_th``.
    Connected regions of the result are kept when their physical area
    (pixel count times ``spacing[1] * spacing[2]``) exceeds ``area_th`` and
    their eccentricity is below ``eccen_th``.

    Returns:
        Boolean array with the same shape as ``image``.
    """
    bw = np.zeros(image.shape, dtype=bool)

    # Mask that is 1 inside the circle inscribed in the slice and NaN outside.
    # The grid is built from image.shape[1] only, i.e. slices are treated as square.
    side = image.shape[1]
    axis = np.linspace(-side/2+0.5, side/2-0.5, side)
    xs, ys = np.meshgrid(axis, axis)
    radius = (xs**2+ys**2)**0.5
    nan_mask = (radius<side/2).astype(float)
    nan_mask[nan_mask == 0] = np.nan

    for idx in range(image.shape[0]):
        plane = image[idx].astype('float32')
        corner_patch = image[idx, 0:bg_patch_size, 0:bg_patch_size]
        # A uniform corner patch triggers the NaN mask, applied to the slice
        # before Gaussian filtering.
        if len(np.unique(corner_patch)) == 1:
            plane = np.multiply(plane, nan_mask)
        smoothed = scipy.ndimage.filters.gaussian_filter(plane, sigma, truncate=2.0)
        below = smoothed < intensity_th

        # Keep only the regions that pass the area and eccentricity tests.
        label = measure.label(below)
        keep = set(
            prop.label
            for prop in measure.regionprops(label)
            if prop.area * spacing[1] * spacing[2] > area_th and prop.eccentricity < eccen_th
        )
        bw[idx] = np.in1d(label, list(keep)).reshape(label.shape)

    return bw

def all_slice_analysis(bw, spacing, cut_num=0, vol_limit=[0.68, 8.2], area_th=6e3, dist_th=62):
    """Select 3-D connected components of a binary volume by volume and position.

    Components touching the sampled corner / edge-midpoint voxels of the first
    and last retained slice are discarded, as are components whose volume
    falls outside ``vol_limit``. Of the rest, a component is kept when, over
    the slices where its area exceeds ``area_th``, the average of its minimum
    distance to the slice centre is below ``dist_th``.

    Args:
        bw: 3-D boolean volume. When ``cut_num > 0`` its last ``cut_num``
            slices along axis 0 are cleared IN PLACE; pass a copy if the
            input must be preserved.
        spacing: numpy array of per-axis voxel spacing (``.prod()`` is used).
        cut_num: number of trailing slices to blank before labelling; they
            are partly restored at the end.
        vol_limit: ``[min, max]`` component volume, each multiplied by 1e6 and
            compared with ``voxel_count * spacing.prod()`` (presumably litres
            when spacing is in mm; confirm with the caller). Never mutated here.
        area_th: per-slice area threshold, in ``spacing[1] * spacing[2]`` units.
        dist_th: threshold on the average minimum distance to the centre axis.

    Returns:
        ``(mask, n_valid)``: the boolean mask of selected components and the
        number of components that passed the area/distance test.
    """
    # in some cases, several top layers need to be removed first
    if cut_num > 0:
        bw0 = np.copy(bw)
        bw[-cut_num:] = False
    label = measure.label(bw, connectivity=1)

    # remove components that reach the corners / edge midpoints
    mid = int(label.shape[2] / 2)
    last = -1 - cut_num
    bg_label = set([label[0, 0, 0], label[0, 0, -1], label[0, -1, 0], label[0, -1, -1],
                    label[last, 0, 0], label[last, 0, -1], label[last, -1, 0], label[last, -1, -1],
                    label[0, 0, mid], label[0, -1, mid], label[last, 0, mid], label[last, -1, mid]])
    for l in bg_label:
        label[label == l] = 0

    # select components based on volume
    voxel_volume = spacing.prod()
    for prop in measure.regionprops(label):
        volume = prop.area * voxel_volume
        if volume < vol_limit[0] * 1e6 or volume > vol_limit[1] * 1e6:
            label[label == prop.label] = 0

    # prepare a distance map for further analysis
    x_axis = np.linspace(-label.shape[1]/2+0.5, label.shape[1]/2-0.5, label.shape[1]) * spacing[1]
    y_axis = np.linspace(-label.shape[2]/2+0.5, label.shape[2]/2-0.5, label.shape[2]) * spacing[2]
    # 'ij' indexing makes ``d`` share the (shape[1], shape[2]) layout of a
    # slice. The default 'xy' indexing yields the transpose, which only
    # coincides for square slices with equal in-plane spacing.
    x, y = np.meshgrid(x_axis, y_axis, indexing='ij')
    d = (x**2+y**2)**0.5
    d_max = np.max(d)

    vols = measure.regionprops(label)
    valid_label = set()
    # select components based on their area and distance to center axis on all slices
    for vol in vols:
        single_vol = label == vol.label
        slice_area = np.zeros(label.shape[0])
        min_distance = np.zeros(label.shape[0])
        for i in range(label.shape[0]):
            slice_area[i] = np.sum(single_vol[i]) * np.prod(spacing[1:3])
            # Outside the component the distance is replaced by the maximum so
            # it never wins the minimum.
            min_distance[i] = np.min(single_vol[i] * d + (1 - single_vol[i]) * d_max)

        # An empty selection averages to NaN, which fails the comparison.
        if np.average([min_distance[i] for i in range(label.shape[0]) if slice_area[i] > area_th]) < dist_th:
            valid_label.add(vol.label)

    bw = np.in1d(label, list(valid_label)).reshape(label.shape)

    # fill back the parts removed earlier
    if cut_num > 0:
        # bw1 is bw with removed slices, bw2 is a dilated version of bw, part of their intersection is returned as final mask
        bw1 = np.copy(bw)
        bw1[-cut_num:] = bw0[-cut_num:]
        bw2 = np.copy(bw)
        bw2 = scipy.ndimage.binary_dilation(bw2, iterations=cut_num)
        bw3 = bw1 & bw2
        label = measure.label(bw, connectivity=1)
        label3 = measure.label(bw3, connectivity=1)
        l_list = list(set(np.unique(label)) - {0})
        valid_l3 = set()
        for l in l_list:
            # Use the first voxel of each selected component to find the
            # component of bw3 that contains it.
            indices = np.nonzero(label==l)
            l3 = label3[indices[0][0], indices[1][0], indices[2][0]]
            if l3 > 0:
                valid_l3.add(l3)
        bw = np.in1d(label3, list(valid_l3)).reshape(label3.shape)

    return bw, len(valid_label)

def fill_hole(bw):
    """Fill 3-D holes of a binary volume.

    The complement of ``bw`` is labelled; every voxel whose label differs from
    the labels found at the eight corner voxels is set to True in the result.
    """
    inverse_label = measure.label(~bw)
    ends = (0, -1)
    # labels of the components that contain a corner voxel
    corner_labels = set(inverse_label[z, y, x] for z in ends for y in ends for x in ends)
    outside = np.in1d(inverse_label, list(corner_labels)).reshape(inverse_label.shape)
    return ~outside

def two_lung_only(bw, spacing, max_iter=22, max_ratio=4.8):
    """Split a binary volume into its two largest components and clean them up.

    The volume is eroded repeatedly (at most ``max_iter`` times) until its two
    largest connected components have an area ratio below ``max_ratio``. When
    that succeeds, every voxel of the original volume is assigned to the
    nearer of the two components; otherwise the whole volume is returned as
    the first mask and the second mask is empty.

    Returns:
        ``(bw1, bw2, bw)``: the two masks and their union.
    """
    def extract_main(bw, cover=0.95):
        """Keep the dominant regions of each slice, then the largest 3-D component.

        Modifies the slices of ``bw`` in place before relabelling.
        """
        for i in range(bw.shape[0]):
            current_slice = bw[i]
            label = measure.label(current_slice)
            properties = measure.regionprops(label)
            properties.sort(key=lambda x: x.area, reverse=True)
            area = [prop.area for prop in properties]
            # Count how many of the largest regions are needed to reach the
            # ``cover`` fraction of the slice's total area.
            count = 0
            sum = 0
            while sum < np.sum(area)*cover:
                sum = sum+area[count]
                count = count+1
            # Union of the convex images of those regions, placed at their bounding boxes.
            filter = np.zeros(current_slice.shape, dtype=bool)
            for j in range(count):
                bb = properties[j].bbox
                filter[bb[0]:bb[2], bb[1]:bb[3]] = filter[bb[0]:bb[2], bb[1]:bb[3]] | properties[j].convex_image
            bw[i] = bw[i] & filter
           
        # Keep only the largest connected component of the filtered volume.
        # NOTE(review): properties[0] raises IndexError when the volume is empty.
        label = measure.label(bw)
        properties = measure.regionprops(label)
        properties.sort(key=lambda x: x.area, reverse=True)
        bw = label==properties[0].label

        return bw
    
    def fill_2d_hole(bw):
        """OR each region's ``filled_image`` into its bounding box, slice by slice (in place)."""
        for i in range(bw.shape[0]):
            # current_slice is obtained by indexing bw, so the updates below
            # write through to bw when it is a numpy array.
            current_slice = bw[i]
            label = measure.label(current_slice)
            properties = measure.regionprops(label)
            for prop in properties:
                bb = prop.bbox
                current_slice[bb[0]:bb[2], bb[1]:bb[3]] = current_slice[bb[0]:bb[2], bb[1]:bb[3]] | prop.filled_image
            bw[i] = current_slice

        return bw
    
    found_flag = False
    iter_count = 0
    # Untouched copy of the input; the loop below erodes ``bw`` itself.
    bw0 = np.copy(bw)
    while not found_flag and iter_count < max_iter:
        label = measure.label(bw, connectivity=2)
        properties = measure.regionprops(label)
        properties.sort(key=lambda x: x.area, reverse=True)
        # NOTE(review): under Python 2 this `/` is integer division if both
        # areas are integers, which loosens the comparison — confirm intent.
        if len(properties) > 1 and properties[0].area/properties[1].area < max_ratio:
            found_flag = True
            bw1 = label == properties[0].label
            bw2 = label == properties[1].label
        else:
            bw = scipy.ndimage.binary_erosion(bw)
            iter_count = iter_count + 1
    
    if found_flag:
        # Distance of every voxel to each of the two components, scaled by ``spacing``.
        d1 = scipy.ndimage.morphology.distance_transform_edt(bw1 == False, sampling=spacing)
        d2 = scipy.ndimage.morphology.distance_transform_edt(bw2 == False, sampling=spacing)
        # Assign the voxels of the un-eroded volume to the nearer component
        # (voxels exactly equidistant are assigned to neither).
        bw1 = bw0 & (d1 < d2)
        bw2 = bw0 & (d1 > d2)
                
        bw1 = extract_main(bw1)
        bw2 = extract_main(bw2)
        
    else:
        # No acceptable pair was found: return everything as the first mask.
        bw1 = bw0
        bw2 = np.zeros(bw.shape).astype('bool')
        
    bw1 = fill_2d_hole(bw1)
    bw2 = fill_2d_hole(bw2)
    bw = bw1 | bw2

    return bw1, bw2, bw

def step1_python_tianchi(case_path):
    """Load ``case_path + '.mhd'`` and segment it into two masks.

    The volume is binarised slice by slice, then analysed in 3-D with an
    increasing number of trailing slices removed (steps of 2) until at least
    one component is accepted or every slice has been tried. Holes are filled
    and the result is split by ``two_lung_only``.

    Returns:
        ``(case_pixels, bw1, bw2, spacing, origin, isflip)``.
    """
    sliceim, origin, spacing, isflip = load_itk_image(case_path + '.mhd')
    if isflip:
        # undo the flip on the two in-plane axes
        sliceim = sliceim[:, ::-1, ::-1]
        print('flip!')
    case_pixels = np.array(sliceim)

    binary = binarize_per_slice(case_pixels, spacing)
    cut_step = 2
    cut_num = 0
    n_found = 0
    bw = binary
    while n_found == 0 and cut_num < bw.shape[0]:
        # all_slice_analysis may modify its input, so hand it a fresh copy.
        bw, n_found = all_slice_analysis(np.copy(binary), spacing, cut_num=cut_num, vol_limit=[0.68,7.5])
        cut_num += cut_step

    bw = fill_hole(bw)
    bw1, bw2, bw = two_lung_only(bw, spacing)
    return case_pixels, bw1, bw2, spacing, origin, isflip

def savenpy(id, annos, filelist, data_path, prep_folder):
    """Preprocess case ``filelist[id]`` and save the results under ``prep_folder``.

    Writes ``<name>_clean.npy`` (masked, resampled, cropped uint8 image with a
    leading singleton axis), ``_originbox.npy`` (crop box), ``_spacing.npy``,
    ``_origin.npy`` and ``_label.npy`` (annotations in cropped coordinates, or
    ``[[0, 0, 0, 0]]`` when the case has none).

    If all five files already exist the case is skipped, except that for a
    flipped scan ``<name>_mask.npy`` is written before returning.

    Args:
        id: index into ``filelist``.
        annos: 2-D array of annotations; column 0 is the case name, columns
            1-3 the x, y, z world coordinates and column 4 the diameter.
        filelist: case names without extension.
        data_path: directory holding the ``.mhd`` files.
        prep_folder: output directory.

    Raises:
        AssertionError: if the segmentation mask is empty or the crop box is
            degenerate along some axis.
    """
    resolution = np.array([1, 1, 1])
    name = filelist[id]
    im, m1, m2, spacing, origin, isflip = step1_python_tianchi(os.path.join(data_path, name))

    def out_path(suffix):
        return os.path.join(prep_folder, name + suffix)

    outputs = ('_clean.npy', '_originbox.npy', '_spacing.npy', '_origin.npy', '_label.npy')
    missingmask = False
    if all(os.path.exists(out_path(suffix)) for suffix in outputs):
        if not isflip:
            print('skip %s' % name)
            return
        missingmask = True

    print('process %s' % name)

    Mask = m1 + m2  # boolean union of the two masks

    newshape = np.round(np.array(Mask.shape) * spacing / resolution)
    xx, yy, zz = np.where(Mask)
    if xx.size == 0:
        print(name)
        # Raised explicitly so the check survives ``python -O``.
        raise AssertionError('empty mask for %s' % name)

    # Bounding box of the mask, converted to the target resolution and
    # extended by a margin (clipped to the resampled volume).
    box = np.array([[np.min(xx), np.max(xx)], [np.min(yy), np.max(yy)], [np.min(zz), np.max(zz)]])
    box = box * np.expand_dims(spacing, 1) / np.expand_dims(resolution, 1)
    box = np.floor(box).astype('int')
    margin = 5
    extendbox = np.vstack([np.max([[0,0,0],box[:,0]-margin],0),np.min([newshape,box[:,1]+2*margin],axis=0).T]).T
    extendbox = extendbox.astype('int')
    if np.any(extendbox[:, 0] == extendbox[:, 1]):
        print(name)
        raise AssertionError('degenerate crop box for %s' % name)

    if missingmask:
        # Everything else is already on disk; only the mask is written.
        np.save(out_path('_mask.npy'), Mask)
        print('skip %s' % name)
        return

    dm1 = process_mask(m1)
    dm2 = process_mask(m2)
    dilatedMask = dm1 + dm2
    # Voxels added by the dilation. ``-`` on boolean arrays raises TypeError
    # on current NumPy; ``^`` is the supported spelling.
    extramask = dilatedMask ^ Mask
    bone_thresh = 210
    pad_value = 170
    im[np.isnan(im)] = -2000
    sliceim = lumTrans(im)
    # Outside the dilated mask everything becomes pad_value.
    sliceim = sliceim*dilatedMask+pad_value*(1-dilatedMask).astype('uint8')
    # Bright voxels in the dilated rim are replaced by pad_value as well.
    bones = sliceim*extramask > bone_thresh
    sliceim[bones] = pad_value
    sliceim1, _ = resample(sliceim, spacing, resolution, order=1)
    sliceim2 = sliceim1[extendbox[0,0]:extendbox[0,1],
                        extendbox[1,0]:extendbox[1,1],
                        extendbox[2,0]:extendbox[2,1]]
    sliceim = sliceim2[np.newaxis, ...]
    np.save(out_path('_clean.npy'), sliceim)
    np.save(out_path('_originbox.npy'), extendbox)
    np.save(out_path('_spacing.npy'), spacing)
    np.save(out_path('_origin.npy'), origin)
    print('%s _clean %s _originbox %s _space %s _origin %s'
          % (im.shape, sliceim.shape, extendbox.shape, spacing, origin))

    this_annos = np.copy(annos[annos[:,0]==name])
    label = []
    print('label %s %s' % (this_annos.shape, name))
    for c in this_annos:
        # world (x, y, z) -> voxel (z, y, x)
        pos = worldToVoxelCoord(c[1:4][::-1], origin=origin, spacing=spacing)
        if isflip:
            pos[1:] = Mask.shape[1:3]-pos[1:]
        label.append(np.concatenate([pos, [c[4]/spacing[1]]]))

    label = np.array(label)
    if len(label) == 0:
        label2 = np.array([[0,0,0,0]])
    else:
        # Rescale to the target resolution, then shift into the cropped frame.
        label2 = np.copy(label).T
        label2[:3] = label2[:3]*np.expand_dims(spacing,1)/np.expand_dims(resolution,1)
        label2[3] = label2[3]*spacing[1]/resolution[1]
        label2[:3] = label2[:3]-np.expand_dims(extendbox[:,0],1)
        label2 = label2[:4].T
    np.save(out_path('_label.npy'), label2)
    print(name)

def _load_annos(csv_path):
    """Read an annotation CSV into an array of rows (filename, x, y, z, d)."""
    content = np.array(pandas.read_csv(csv_path))
    # NOTE(review): read_csv already consumes the header line, so this also
    # drops the first data row — kept as in the original, confirm intent.
    content = content[1:, :]
    # ``!= np.nan`` is always True; notnull() really removes the rows whose
    # filename is missing (they could never match a case name anyway).
    return content[pandas.notnull(content[:, 0])]


def _list_mhd_cases(data_path):
    """Return the non-blacklisted case names (``.mhd`` stripped) in ``data_path``."""
    return [f[:-4] for f in os.listdir(data_path)
            if f.endswith('.mhd') and f[:-4] not in config['black_list']]


def full_prep(train=True, val=True, test=True):
    """Run ``savenpy`` over the train / val / test splits with a worker pool.

    Nothing is processed when the flag file ``.flag_preptianchi`` already
    exists. The flag file is (re)written at the end in every case.

    Args:
        train, val, test: whether to process the corresponding split.
    """
    warnings.filterwarnings("ignore")
    train_prep_folder = config['train_preprocess_result_path']
    val_prep_folder = config['val_preprocess_result_path']
    test_prep_folder = config['test_preprocess_result_path']

    train_data_path = config['train_data_path']
    val_data_path = config['val_data_path']
    test_data_path = config['test_data_path']

    finished_flag = '.flag_preptianchi'

    if not os.path.exists(finished_flag):
        for folder in (train_prep_folder, val_prep_folder, test_prep_folder):
            if not os.path.exists(folder):
                os.mkdir(folder)

        # (split name, annotation config key, data dir, output dir) of the requested splits
        splits = [
            ('train', 'train_annos_path', train_data_path, train_prep_folder),
            ('val', 'val_annos_path', val_data_path, val_prep_folder),
            ('test', 'test_annos_path', test_data_path, test_prep_folder),
        ]
        selected = [s for s, wanted in zip(splits, (train, val, test)) if wanted]

        if selected:
            # One pool shared by every requested split and always cleaned up.
            # It used to be created only for ``train`` and closed only for ``test``.
            pool = Pool(10)
            try:
                for split, annos_key, data_path, prep_folder in selected:
                    print('starting %s preprocessing' % split)
                    filelist = _list_mhd_cases(data_path)
                    partial_savenpy = partial(savenpy, annos=_load_annos(config[annos_key]),
                                              filelist=filelist, data_path=data_path,
                                              prep_folder=prep_folder)
                    pool.map(partial_savenpy, range(len(filelist)))
                    print('end %s preprocessing' % split)
            finally:
                pool.close()
                pool.join()

    # Touch the flag file, closing the handle right away.
    with open(finished_flag, "w+"):
        pass
    
def splitvaltestcsv():
    """Move the annotation rows of test cases out of the validation CSV.

    Every row of ``config['val_annos_path']`` whose first field names a
    ``.mhd`` case found in ``config['test_data_path']`` is written to the test
    annotation file; the remaining rows (including empty ones) are written
    back to the validation CSV, which is overwritten.
    """
    import csv

    # A set gives constant-time membership tests for each CSV row.
    testfiles = set(f[:-4] for f in os.listdir(config['test_data_path']) if f.endswith('.mhd'))

    valcsvlines = []
    testcsvlines = []
    with open(config['val_annos_path'], 'r') as valf:
        for line in csv.reader(valf):
            # csv.reader yields [] for blank lines; keep those with the val rows.
            if line and line[0] in testfiles:
                testcsvlines.append(line)
            else:
                valcsvlines.append(line)

    # NOTE(review): here 'test_annos_path' is used as a prefix for
    # 'annotations.csv', while full_prep reads it as the CSV path itself —
    # confirm which is intended.
    with open(config['test_annos_path']+'annotations.csv', 'w') as testf:
        csv.writer(testf).writerows(testcsvlines)

    with open(config['val_annos_path'], 'w') as valf:
        csv.writer(valf).writerows(valcsvlines)

def savenpy_luna(id, annos, filelist, luna_segment, luna_data,savepath):
    """Preprocess case ``filelist[id]`` using a precomputed segmentation volume.

    The image is read from ``luna_data`` and the segmentation from
    ``luna_segment``; voxels labelled 3 or 4 in the segmentation form the
    mask. Saves ``<name>_clean.npy``, ``_spacing.npy``, ``_extendbox.npy``,
    ``_origin.npy``, ``_mask.npy`` and ``_label.npy`` under ``savepath``.
    """
    islabel = True
    isClean = True
    resolution = np.array([1, 1, 1])
    name = filelist[id]

    sliceim, origin, spacing, isflip = load_itk_image(os.path.join(luna_data, name + '.mhd'))
    # Origin, spacing and flip flag of the segmentation file replace those
    # just read from the image file.
    Mask, origin, spacing, isflip = load_itk_image(os.path.join(luna_segment, name + '.mhd'))
    if isflip:
        Mask = Mask[:, ::-1, ::-1]
    newshape = np.round(np.array(Mask.shape) * spacing / resolution).astype('int')
    m1 = Mask == 3
    m2 = Mask == 4
    Mask = m1 + m2

    # Bounding box of the mask at the target resolution, extended by a margin.
    zs, ys, xs = np.where(Mask)
    box = np.array([[np.min(zs), np.max(zs)],
                    [np.min(ys), np.max(ys)],
                    [np.min(xs), np.max(xs)]])
    box = box * np.expand_dims(spacing, 1) / np.expand_dims(resolution, 1)
    box = np.floor(box).astype('int')
    margin = 5
    lower = np.max([[0, 0, 0], box[:, 0] - margin], 0)
    upper = np.min([newshape, box[:, 1] + 2 * margin], axis=0)
    extendbox = np.vstack([lower, upper]).T

    if isClean:
        dilatedMask = process_mask(m1) + process_mask(m2)
        Mask = m1 + m2

        # voxels added by the dilation
        extramask = dilatedMask ^ Mask
        bone_thresh = 210
        pad_value = 170

        if isflip:
            sliceim = sliceim[:, ::-1, ::-1]
            print('flip!')
        sliceim = lumTrans(sliceim)
        sliceim = sliceim * dilatedMask + pad_value * (1 - dilatedMask).astype('uint8')
        bones = (sliceim * extramask) > bone_thresh
        sliceim[bones] = pad_value

        resampled, _ = resample(sliceim, spacing, resolution, order=1)
        cropped = resampled[extendbox[0, 0]:extendbox[0, 1],
                            extendbox[1, 0]:extendbox[1, 1],
                            extendbox[2, 0]:extendbox[2, 1]]
        sliceim = cropped[np.newaxis, ...]
        for suffix, payload in (('_clean.npy', sliceim),
                                ('_spacing.npy', spacing),
                                ('_extendbox.npy', extendbox),
                                ('_origin.npy', origin),
                                ('_mask.npy', Mask)):
            np.save(os.path.join(savepath, name + suffix), payload)

    if islabel:
        this_annos = np.copy(annos[annos[:, 0] == (name)])
        label = []
        for c in this_annos:
            # world (x, y, z) -> voxel (z, y, x)
            pos = worldToVoxelCoord(c[1:4][::-1], origin=origin, spacing=spacing)
            if isflip:
                pos[1:] = Mask.shape[1:3] - pos[1:]
            label.append(np.concatenate([pos, [c[4] / spacing[1]]]))

        label = np.array(label)
        if len(label) == 0:
            label2 = np.array([[0, 0, 0, 0]])
        else:
            # Rescale to the target resolution, then shift into the cropped frame.
            label2 = np.copy(label)
            label2[:, :3] = label2[:, :3] * spacing / resolution
            label2[:, 3] = label2[:, 3] * spacing[1] / resolution[1]
            label2[:, :3] = label2[:, :3] - extendbox[:, 0]
            label2 = label2[:, :4]
        np.save(os.path.join(savepath, name + '_label.npy'), label2)

    print(name)

def preprocess_luna():
    """Run ``savenpy_luna`` over the ten ``subset0`` .. ``subset9`` folders.

    Paths come from ``config``. Nothing is processed when the flag file
    ``.flag_preprocessluna`` already exists; the flag file is (re)written at
    the end in every case.
    """
    luna_segment = config['luna_segment']
    savepath = config['preprocess_result_path']
    luna_data = config['luna_data']
    luna_label = config['luna_label']
    finished_flag = '.flag_preprocessluna'
    print('starting preprocessing luna')
    if not os.path.exists(finished_flag):
        annos = np.array(pandas.read_csv(luna_label))
        if not os.path.exists(savepath):
            os.mkdir(savepath)
        pool = Pool()
        try:
            for setidx in range(10):
                print('process subset %d' % setidx)
                subset = 'subset' + str(setidx)
                # Paths are built by plain concatenation, so the configured
                # directories are expected to end with a path separator.
                filelist = [f.split('.mhd')[0] for f in os.listdir(luna_data + subset) if f.endswith('.mhd')]
                if not os.path.exists(savepath + subset):
                    os.mkdir(savepath + subset)
                partial_savenpy_luna = partial(savenpy_luna, annos=annos, filelist=filelist,
                                               luna_segment=luna_segment,
                                               luna_data=luna_data + subset + '/',
                                               savepath=savepath + subset + '/')
                pool.map(partial_savenpy_luna, range(len(filelist)))
        finally:
            # Release the workers even when a subset fails.
            pool.close()
            pool.join()
    print('end preprocessing luna')
    # Touch the flag file, closing the handle right away.
    with open(finished_flag, "w+"):
        pass


# Script entry point: running this file directly preprocesses the LUNA data.
if __name__=='__main__':
    preprocess_luna()