# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Preprocessing script.

This script walks over the directories and dump the frames into a csv file
"""
import os
import random

import numpy as np
import scipy.misc  # provides imsave (removed in SciPy >= 1.2)
import dicom  # legacy pydicom (< 1.0) import name
from skimage import transform
from joblib import Parallel, delayed
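# Assumed on-disk layout (inferred from the paths used below): frames live under
# ./data/train/<case_id>/... in directories whose path contains "sax", named
# <prefix>-0001.dcm through <prefix>-0030.dcm, and ./data/train.csv holds one
# header line followed by one label line per case.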

def mkdir(fname):
    """Create the directory fname, ignoring the OSError raised if it already exists."""
    try:
        os.mkdir(fname)
    except OSError:
        pass

def get_frames(root_path):
    """Return, for every SAX view that contains a complete set of 30 frames, the ordered list of frame paths."""
    ret = []
    for root, _, files in os.walk(root_path):
        root = root.replace('\\', '/')
        files = [s for s in files if ".dcm" in s]
        # keep only "sax" view directories that actually contain DICOM files
        if len(files) == 0 or not files[0].endswith(".dcm") or root.find("sax") == -1:
            continue
        # a complete series is named <prefix>-0001.dcm through <prefix>-0030.dcm
        prefix = files[0].rsplit('-', 1)[0]
        fileset = set(files)
        expected = ["%s-%04d.dcm" % (prefix, i + 1) for i in range(30)]
        if all(x in fileset for x in expected):
            ret.append([root + "/" + x for x in expected])
    # sort for reproducibility
    return sorted(ret, key=lambda x: x[0])


def get_label_map(fname):
    """Map each case index to its full label line from the label CSV (header skipped)."""
    labelmap = {}
    with open(fname) as fi:
        fi.readline()  # skip the header line
        for line in fi:
            arr = line.split(',')
            labelmap[int(arr[0])] = line
    return labelmap


def write_label_csv(fname, frames, label_map):
    """Write one label line per series, or placeholder labels if label_map is None."""
    with open(fname, "w") as fo:
        for lst in frames:
            # frame paths look like ./data/<set>/<case>/..., so path element 3 is the case index
            index = int(lst[0].split("/")[3])
            if label_map is not None:
                fo.write(label_map[index])
            else:
                fo.write("%d,0,0\n" % index)


def get_data(lst, preproc):
    """Preprocess one series: save a JPEG preview per frame and return the pixel data as strings."""
    data = []
    result = []
    for path in lst:
        f = dicom.read_file(path)
        # normalize the pixel values to [0, 1] before preprocessing
        img = preproc(f.pixel_array.astype(float) / np.max(f.pixel_array))
        dst_path = path.rsplit(".", 1)[0] + ".64x64.jpg"
        scipy.misc.imsave(dst_path, img)
        result.append(dst_path)
        data.append(img)
    # flatten the whole series into one row of uint8 pixel values,
    # converted to strings so the row can be joined into a CSV line
    data = np.array(data, dtype=np.uint8)
    data = data.reshape(data.size)
    data = np.array(data, dtype=np.str_)
    return [data, result]


def write_data_csv(fname, frames, preproc):
    """Write the preprocessed pixel data of every series to fname, one row per series."""
    # Parallel() defaults to a single worker; pass n_jobs=-1 to use all cores
    dr = Parallel()(delayed(get_data)(lst, preproc) for lst in frames)
    data, result = zip(*dr)
    with open(fname, "w") as fdata:
        for entry in data:
            fdata.write(','.join(entry) + '\r\n')
    print("All finished, %d slices in total" % len(data))
    # return the flat list of JPEG preview paths
    return np.ravel(result)


def crop_resize(img, size):
    """Crop the largest centered square from img and resize it to (size, size)."""
    # transpose landscape images so every frame has the same orientation
    if img.shape[0] < img.shape[1]:
        img = img.T
    # crop the image from the center
    short_edge = min(img.shape[:2])
    yy = int((img.shape[0] - short_edge) / 2)
    xx = int((img.shape[1] - short_edge) / 2)
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
    # resize to size x size and rescale from [0, 1] to 8-bit pixel values
    resized_img = transform.resize(crop_img, (size, size))
    resized_img *= 255
    return resized_img.astype("uint8")


def local_split(train_index):
    """Randomly split the unique training indices into local train and test sets (roughly 2:1)."""
    random.seed(0)
    train_index = set(train_index)
    all_index = sorted(train_index)
    num_test = int(len(all_index) / 3)
    random.shuffle(all_index)
    train_set = set(all_index[num_test:])
    test_set = set(all_index[:num_test])
    return train_set, test_set


def split_csv(src_csv, split_to_train, train_csv, test_csv):
    """Split the rows of src_csv into train_csv and test_csv according to split_to_train."""
    with open(src_csv) as fsrc, open(train_csv, "w") as ftrain, open(test_csv, "w") as ftest:
        for cnt, line in enumerate(fsrc):
            if split_to_train[cnt]:
                ftrain.write(line)
            else:
                ftest.write(line)

# Load the list of all the training frames and shuffle them (seeded for reproducibility).
random.seed(10)
train_frames = get_frames("./data/train")
random.shuffle(train_frames)
validate_frames = get_frames("./data/validate")

# Write the corresponding label information of each series into a file (placeholder labels for the validation set).
write_label_csv("./train-label.csv", train_frames, get_label_map("./data/train.csv"))
write_label_csv("./validate-label.csv", validate_frames, None)

# Dump the pixel data of each frame into a CSV file, applying the 64x64 center-crop preprocessor.
train_lst = write_data_csv("./train-64x64-data.csv", train_frames, lambda x: crop_resize(x, 64))
valid_lst = write_data_csv("./validate-64x64-data.csv", validate_frames, lambda x: crop_resize(x, 64))
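# train_lst / valid_lst hold the paths of the 64x64 JPEG previews that get_data wrote next to each DICOM frame.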

# Generate a local train/test split, which you can use to tune your model locally.
train_index = np.loadtxt("./train-label.csv", delimiter=",")[:,0].astype("int")
train_set, test_set = local_split(train_index)
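# One boolean per row of train-label.csv, in the same row order that split_csv reads;
# the data CSV was written in the same frame order, so the same mask applies to it as well.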
split_to_train = [x in train_set for x in train_index]
split_csv("./train-label.csv", split_to_train, "./local_train-label.csv", "./local_test-label.csv")
split_csv("./train-64x64-data.csv", split_to_train, "./local_train-64x64-data.csv", "./local_test-64x64-data.csv")