#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug  7 17:50:47 2019

@author: lferiani
"""

#%% import statements
import numpy as np
import pandas as pd
from numpy.fft import fft2, ifft2, fftshift

#%% constants

# Legacy dictionary to go from camera serial number (a string) to channel.
# To be updated as we get more copies of the LoopBio rig.
# NOTE(review): this maps "22594549" to 'Ch1', whereas CAM2CH_list lists the
# same serial as ('Ch5', 'Hydra05') - presumably an older camera layout;
# confirm before relying on it.
# NOTE(review): the dict literal is not closed in this view - any remaining
# entries and the closing brace appear to be missing; restore from history.
CAM2CH_DICT_legacy = {"22594549":'Ch1',

# Camera serial number (a string) -> channel, for the Hydra rigs.
# Only the Ch1 camera of each rig (Hydra01-Hydra05) is visible here; these
# entries agree with the Ch1 rows of CAM2CH_list below.
# NOTE(review): the dict literal is not closed in this view - the entries for
# the other cameras and the closing brace appear to be missing; CAM2CH_list
# can be used to cross-check/restore them.
CAM2CH_DICT = {"22956818":'Ch1', # Hydra01
               "22956839":'Ch1', # Hydra02
               "22956814":'Ch1', # Hydra03
               "22956812":'Ch1', # Hydra04
               "22594559":'Ch1', # Hydra05

# This can't be a nice and simple dictionary because people may want to use
# this info in the other direction too (channel/rig -> serial as well as
# serial -> channel/rig). So we keep a flat list of
# (camera_serial, channel, rig) records, plus a DataFrame view of it.

# Serials of the six cameras of each rig, listed in channel order (Ch1..Ch6).
_RIG_CAMERA_SERIALS = (
    ('Hydra01', ('22956818', '22956816', '22956813',
                 '22956805', '22956807', '22956832')),
    ('Hydra02', ('22956839', '22956837', '22956836',
                 '22956829', '22956822', '22956806')),
    ('Hydra03', ('22956814', '22956833', '22956819',
                 '22956827', '22956823', '22956840')),
    ('Hydra04', ('22956812', '22956834', '22956817',
                 '22956811', '22956831', '22956809')),
    ('Hydra05', ('22594559', '22594547', '22594546',
                 '22436248', '22594549', '22594548')),
)

# one (camera_serial, channel, rig) tuple per camera, rig by rig, Ch1 first
CAM2CH_list = [(serial, 'Ch{}'.format(ch_number), rig)
               for rig, serials in _RIG_CAMERA_SERIALS
               for ch_number, serial in enumerate(serials, start=1)]

# same records as a table, to allow filtering on any of the three fields
CAM2CH_df = pd.DataFrame(CAM2CH_list,
                         columns=['camera_serial', 'channel', 'rig'])

# DataFrames to go from channel/(col, row) to well name: each key is a
# (channel, column index within that channel's field of view) pair, and its
# value lists the well names of that column in row order (A, B, C, ...).
# There will be many, as it depends on total number of wells,
# upright/upside-down, and, in the case of the 48wp, how many wells are in
# each fov.

# 48-well plate, upright.
# NOTE(review): the '669999' suffix presumably encodes wells per channel
# (6+6+9+9+9+9 = 48) - confirm.
# NOTE(review): only the first entry is visible; the remaining entries, the
# closing brace and the rest of the from_dict call appear to be missing.
UPRIGHT_48WP_669999 = pd.DataFrame.from_dict({ ('Ch1',0):['A1','B1','C1'],

# 96-well plate, upright. Ch1-Ch4 each map to a 4x4 block of wells:
# Ch1 -> rows A-D, columns 1-4; Ch2 -> rows E-H, columns 1-4;
# Ch3 -> rows A-D, columns 5-8; Ch4 -> rows E-H, columns 5-8.
# Ch5 (rows A-D) and Ch6 (rows E-H) start at column 9.
# NOTE(review): Ch5 and Ch6 only show column index 0 here, and the dict /
# from_dict call is never closed. 96 wells / 6 channels = 16 wells per
# channel suggests the entries for columns 10-12 of Ch5/Ch6 and the end of
# the call are missing - restore from version control.
UPRIGHT_96WP = pd.DataFrame.from_dict({('Ch1',0):[ 'A1', 'B1', 'C1', 'D1'],
                                       ('Ch1',1):[ 'A2', 'B2', 'C2', 'D2'],
                                       ('Ch1',2):[ 'A3', 'B3', 'C3', 'D3'],
                                       ('Ch1',3):[ 'A4', 'B4', 'C4', 'D4'],
                                       ('Ch2',0):[ 'E1', 'F1', 'G1', 'H1'],
                                       ('Ch2',1):[ 'E2', 'F2', 'G2', 'H2'],
                                       ('Ch2',2):[ 'E3', 'F3', 'G3', 'H3'],
                                       ('Ch2',3):[ 'E4', 'F4', 'G4', 'H4'],
                                       ('Ch3',0):[ 'A5', 'B5', 'C5', 'D5'],
                                       ('Ch3',1):[ 'A6', 'B6', 'C6', 'D6'],
                                       ('Ch3',2):[ 'A7', 'B7', 'C7', 'D7'],
                                       ('Ch3',3):[ 'A8', 'B8', 'C8', 'D8'],
                                       ('Ch4',0):[ 'E5', 'F5', 'G5', 'H5'],
                                       ('Ch4',1):[ 'E6', 'F6', 'G6', 'H6'],
                                       ('Ch4',2):[ 'E7', 'F7', 'G7', 'H7'],
                                       ('Ch4',3):[ 'E8', 'F8', 'G8', 'H8'],
                                       ('Ch5',0):[ 'A9', 'B9', 'C9', 'D9'],
                                       ('Ch6',0):[ 'E9', 'F9', 'G9', 'H9'],

#%% functions

def get_mwp_map(total_n_wells, whichsideup):
    """
    Given a total number of wells, and whether the multiwell plate
    is upright or upside-down, returns a dataframe with the correct
    channel/row/column -> well_name mapping
    (this works on the Hydra imaging systems - by LoopBio Gmbh - used in Andre
    Brown's lab)

    Parameters
    ----------
    total_n_wells : int
        number of wells in the plate; only 48 and 96 are currently supported
    whichsideup : str
        orientation of the plate; only 'upright' is currently supported

    Returns
    -------
    pandas.DataFrame
        the module-level UPRIGHT_48WP_669999 or UPRIGHT_96WP mapping

    Raises
    ------
    ValueError
        for any (total_n_wells, whichsideup) combination not listed above
    """
    if whichsideup == 'upright':
        if total_n_wells == 48:
            return UPRIGHT_48WP_669999
        if total_n_wells == 96:
            return UPRIGHT_96WP
    # every supported case has returned above: anything else is an explicit
    # error rather than a silent None
    raise ValueError('This case has not been coded yet. '
                     'Please contact the devs or open a feature request on GitHub.')

def serial2rigchannel(camera_serial):
    """
    Takes camera serial number, returns a (rig, channel) tuple

    Parameters
    ----------
    camera_serial : str
        camera serial number, e.g. '22956818'. It is compared verbatim with
        CAM2CH_df['camera_serial'], which holds strings, so an int will not
        match.

    Returns
    -------
    tuple
        (rig, channel), e.g. ('Hydra01', 'Ch1')

    Raises
    ------
    ValueError
        if the serial is not in CAM2CH_df
    Exception
        if the serial appears more than once in CAM2CH_df (corrupted table)
    """
    out = CAM2CH_df[CAM2CH_df['camera_serial'] == camera_serial]
    if len(out) == 0:
        raise ValueError('{} unknown as camera serial string'.format(camera_serial))
    if len(out) > 1:
        # the lookup table should have exactly one row per camera
        raise Exception('Multiple hits for {}. split_fov/helper.py corrupted?'.format(camera_serial))
    return tuple(out[['rig', 'channel']].values[0])

def serial2channel(camera_serial):
    """
    Takes camera serial number, returns the channel

    Thin wrapper around serial2rigchannel: it takes the same input and raises
    the same exceptions, but returns only the channel name (e.g. 'Ch1').
    """
    _rig, channel = serial2rigchannel(camera_serial)
    return channel

def parse_camera_serial(filename):
    """
    Extract the 8-digit camera serial number from a file name or path.

    The serial must directly follow a ``20xxxxxx_xxxxxx.`` stamp (presumably
    a YYYYMMDD_HHMMSS timestamp), e.g.
    ``'.../20190808_123456.22956818/metadata.hdf5'`` -> ``'22956818'``.
    If the path contains several such serials, the first one is returned.

    Parameters
    ----------
    filename : str or pathlib.Path
        the full path is searched, not only the base name

    Returns
    -------
    camera_serial : str

    Raises
    ------
    IndexError
        if no serial can be found (same exception type as a bare ``[0]`` on
        an empty match list, but with an informative message)
    """
    import re
    # lookbehind: the serial is only accepted right after the timestamp + '.'
    regex = r"(?<=20\d{6}_\d{6}\.)\d{8}"
    matches = re.findall(regex, str(filename))
    if not matches:
        raise IndexError(
            'No camera serial number found in {}'.format(filename))
    return matches[0]

def calculate_bgnd_from_masked_fulldata(masked_image_file):
    """
    - Opens the masked_image_file hdf5 file, reads the /full_data node and
      creates a "background" by taking the maximum value of each pixel over
      time (i.e. along the first axis of /full_data).
    - Parses the file name to find a camera serial number
    - reads the microns-per-pixel ratio from the masked_image_file

    Parameters
    ----------
    masked_image_file : str or pathlib.Path
        path to a masked video (hdf5) with a /full_data node

    Returns
    -------
    img : numpy.ndarray
        per-pixel maximum of /full_data along axis 0
    camera_serial : str
        serial parsed from the file path by parse_camera_serial
    microns_per_pixel
        as returned by tierpsy's read_unit_conversions

    Raises
    ------
    AssertionError
        if the video is not light-background (i.e. not brightfield)
    """
    from tierpsy.helper.params import read_unit_conversions

    # read attributes of masked_image_file
    _, (microns_per_pixel, _xy_units), is_light_background = \
        read_unit_conversions(masked_image_file)
    # check before touching the data; raised explicitly (same exception type
    # as the former assert) so the check survives `python -O`
    if not is_light_background:
        raise AssertionError(
            'MultiWell recognition is only available for brightfield at the moment')
    # "background" = brightest value each pixel takes over the whole video
    with pd.HDFStore(masked_image_file, 'r') as fid:
        img = np.max(fid.get_node('/full_data'), axis=0)
    camera_serial = parse_camera_serial(masked_image_file)
    return img, camera_serial, microns_per_pixel

def make_square_template(n_pxls=150, rel_width=0.8, blurring=0.1, dtype_out='float'):
    """
    Function that creates a template that approximates a square well.

    The template is ~1 inside a centred square of side rel_width, drops off
    smoothly (tanh profile) outside it, and has a bright frame (value 1)
    painted along its outer edges.

    Parameters
    ----------
    n_pxls : int or float
        side of the square template in pixels (rounded to the nearest int)
    rel_width : float
        side of the well, as a fraction of the template side
    blurring : float
        width of the tanh transition, as a fraction of the template side
    dtype_out : {'float', 'uint8'}
        'float' returns values in [0, 1]; 'uint8' scales them to [0, 255]

    Returns
    -------
    zz : numpy.ndarray
        n_pxls x n_pxls template

    Raises
    ------
    ValueError
        if dtype_out is neither 'float' nor 'uint8'
    """
    # validate up front, so that the default 'float' is accepted and any
    # other value is rejected
    if dtype_out not in ('float', 'uint8'):
        raise ValueError("Only 'float' and 'uint8' are valid dtypes for this")

    n_pxls = int(np.round(n_pxls))
    x = np.linspace(-0.5, 0.5, n_pxls)
    y = np.linspace(-0.5, 0.5, n_pxls)
    xx, yy = np.meshgrid(x, y, sparse=False, indexing='ij')

    # inspired by Mark Shattuck's function to make a colloid's template:
    # each (1 - tanh) factor tends to 2 inside |coord| < rel_width/2 and to 0
    # outside it; dividing the product by 4 brings the interior back to ~1
    zz = (1 - np.tanh((abs(xx) - rel_width/2) / blurring))
    zz = zz * (1 - np.tanh((abs(yy) - rel_width/2) / blurring))
    zz = zz / 4

    # add bright border, 5% of the side wide. Skip it when that rounds down
    # to 0 pixels: zz[-0:, :] would select (and whiten) the whole template
    edge = int(0.05 * n_pxls)
    if edge > 0:
        zz[:edge, :] = 1
        zz[-edge:, :] = 1
        zz[:, :edge] = 1
        zz[:, -edge:] = 1

    if dtype_out == 'uint8':
        zz *= 255
        zz = zz.astype(np.uint8)
    return zz

def was_fov_split(timeseries_data):
    """
    Check if the FOV was split, looking at timeseries_data

    Parameters
    ----------
    timeseries_data : pandas.DataFrame
        timeseries table, with or without a 'well_name' column

    Returns
    -------
    bool
        True if a 'well_name' column exists and holds at least one value
        other than the 'n/a' placeholder. False otherwise, including when the
        column is missing.
    """
    if 'well_name' not in timeseries_data.columns:
        # for some weird reason, save_feats_stats is being called on an old
        # featuresN file without calling save_timeseries_feats_table first
        return False
    # timeseries_data has been updated and now has a well_name column:
    # the fov was split iff at least one row got a real well name
    real_well_names = set(timeseries_data['well_name']) - {'n/a'}
    return len(real_well_names) > 0

def naive_normalise(img):
    """
    Min-max normalise an array to the [0, 1] range.

    Parameters
    ----------
    img : array_like
        image or any numeric array (converted with np.asarray)

    Returns
    -------
    numpy.ndarray
        (img - img.min()) / (img.max() - img.min()). A constant input
        (max == min) returns an all-zero float array instead of the NaNs
        (and RuntimeWarning) that the plain 0/0 division would produce.
    """
    img = np.asarray(img)
    m = img.min()
    M = img.max()
    if M == m:
        # nothing to stretch: avoid 0/0
        return np.zeros(img.shape, dtype=float)
    return (img - m) / (M - m)

def fft_convolve2d(x, y):
    """
    2D (circular) convolution of x and y, computed in Fourier space.

    The element-wise product of the two 2D spectra is transformed back and
    its real part is kept. The result is then fftshift-ed, which moves its
    [0, 0] element to the centre of the array. Intended for x and y of the
    same shape.
    """
    spectrum = fft2(x) * fft2(y)
    convolved = ifft2(spectrum).real
    return fftshift(convolved)

def simulate_wells_lattice(img_shape, x_off, y_off, sp, nwells=None, template_shape='square'): 
    """
    Create mock fov by placing well templates onto a square lattice.
    Very simply uses the input parameters and range to define where the wells
    will go, and then places the template in a padded canvas.
    The canvas is then cut to be of img_shape again.
    This simple approach works because the template is created to be exactly
    spacing large, so templates do not overlap.

    Parameters
    ----------
    img_shape : tuple
        shape of the simulated fov; x_off, y_off and sp are all scaled by
        img_shape[0]
    x_off, y_off : float
        column / row offset of the lattice, as a fraction of img_shape[0]
    sp : float
        lattice spacing (and template side), as a fraction of img_shape[0]
    nwells : int or None
        NOTE(review): only used in a branch whose code is partly missing from
        this view; presumably limits the number of wells per axis - confirm
    template_shape : str
        NOTE(review): not referenced in the visible code; only a square
        template (make_square_template) is built

    Returns
    -------
    cutout_canvas : numpy.ndarray
        the simulated fov, min-max normalised by naive_normalise
    """
    # convert fractions into integers
    # (note: all three are scaled by img_shape[0], the first dimension)
    x_offset = int(x_off*img_shape[0])
    y_offset = int(y_off*img_shape[0])
    spacing = int(sp*img_shape[0])

    # create a padded empty canvas: half the first dimension of padding on
    # every side, so templates near the edges can be placed without clipping
    padding = img_shape[0]//2
    padding_times_2 = padding*2
    padded_shape = tuple(s+padding_times_2 for s in img_shape)
    padded_canvas = np.zeros(padded_shape)

    # determine where the wells will go in the padded canvas
    # NOTE(review): the range(...) calls below are truncated in this view, and
    # an `else:` between the two pairs appears to be missing - as visible the
    # second pair would overwrite the first. Restore from version control.
    if nwells is not None:
        r_wells = range(y_offset+padding,
        c_wells = range(x_offset+padding,
        r_wells = range(y_offset+padding,
        c_wells = range(x_offset+padding,
    # every (row, col) combination of the lattice
    tmpl_pos_in_padded_canvas = [(r,c) for r in r_wells for c in c_wells]

    # make the template for the wells, exactly `spacing` pixels wide
    # NOTE(review): the remaining make_square_template arguments and the
    # closing parenthesis are missing from this view
    tmpl = make_square_template(n_pxls=spacing, 
    # invert (the template values lie in [0, 1], so 1-tmpl does too)
    tmpl = 1-tmpl

    # place wells onto canvas, each template centred on its lattice point:
    # c-(-ts//2) == c + ceil(ts/2), so every slice is exactly ts pixels long.
    # NOTE(review): the `try:` line and the start of the slice assignment
    # (padded_canvas[r-...,) are missing from this view.
    ts = tmpl.shape[0]
    for r,c in tmpl_pos_in_padded_canvas:
                          c-ts//2:c-(-ts//2)] += tmpl
        # NOTE(review): as visible, any placement error is swallowed and only
        # `pdb` is imported - looks like a debugging leftover; consider
        # re-raising instead
        except Exception as e:
            import pdb

    # cut the padding off again
    # NOTE(review): the column part of this slice is missing from this view
    cutout_canvas = padded_canvas[padding:padding+img_shape[0],
    cutout_canvas = naive_normalise(cutout_canvas)
    return cutout_canvas

if __name__ == '__main__':
    # Manual sanity checks, run when this file is executed directly.
    from pathlib import Path

    # 1) test that camera serials return the correct channel.
    #    A serial that is not in CAM2CH_list (e.g. '22594540') would make
    #    serial2channel raise a ValueError.
    for serial, _channel, _rig in CAM2CH_list:
        print('{} -> {}'.format(serial, serial2channel(serial)))

    # 2) check that the camera serial is parsed correctly from the masked
    #    video paths (developer-specific data folder; yields nothing if the
    #    folder is absent).
    src_dir = Path('/Users/lferiani/Desktop/Data_FOVsplitter/evgeny/MaskedVideos/20190808')
    for fname in src_dir.rglob('*.hdf5'):
        print('{} -> {}'.format(fname, parse_camera_serial(fname)))

    # TODO: serial parsing works, but wrong data was seen written in the
    # masked videos - check what went wrong there.