#! /usr/bin/env python3
# ^ Note: this code uses f-strings, which require python 3.6 and higher

# Purpose:
#  This checks for inconsistencies between input-output pairs in datasets (in bulk)
#
#  Author:
#   Scott H. Hawley (scott.hawley@belmont.edu),  May 12, 2020
#
# Methodology:
#   It checks the lengths of the audio, and then performs a cross correlation
#   to measure timing 'skew' or offset.
#
# Usage:
#    $ ./check_timing.py [options] <input> <target>
#  or
#    $ ./check_timing.py [options] <directory>
#  or
#    $ ./check_timing.py [options] <input1> <input2>... <target1> <target2>...
##
#    If <directory> is specified, it will search for an 'input' file
#    and match it with a corresponding 'target' file with the name number, e.g.
#    "input_47_" and "target_47_".
#    If there are subdirectories in <directory> (e.g. Train/ and Test/), it will
#    descend into those.  Watch out if you have symbolic directory links inside
#     your main dataset.
#
# Command-line options:
#    By default it simply outputs analysis data, and takes no actions to fix anything.
#    The following "fix" options will all make changes to the dataset *in place*.
#     meaning the existing dataset is *overwritten*. Thus it is recommended that
#     you only run this script on a *copy* of your dataset,
#    Run the script first without them to see what it's going to do.

#  -a    (Time-)Align the audio, using cross-correlation.  (Don't use this if
#            you've got an 'echo' effect!)
#  -d    Delete any extra files, i.e. input files witout corresponding target outputs,
#            or vice versa.
#  -l    Fix the length: truncate any extra audio appearing in one file but not
#            the other.  (Runs after -a)
#  -m    Force mono (who knows, maybe somebody made a stereo file?)
#  -s    Sample rate; in this case it will force all the files have the same sample rate
#          as that of the first input file it encounters (whatever that is)
#  --fix:  All of the above, i.e. "--fix" == "-adlms"
#    Again, you may not want all the --fix options, so...use with care.
#
#  Example:
#   $ ./check_dataset.py datasets/SignalTrain_LA2A_dataset_rev1/Train/*.wav
# or
#   $ ./check_dataset.py LA2A_Dataset/           # to check
#  then
#   $ ./check_dataset.py --fix LA2A_Dataset/     # to fix everything

from scipy import signal
import numpy as np
import argparse
import sys
import os
import glob
import librosa
from scipy.io import wavfile
import shutil
import re

DEBUG = False

if DEBUG:
    import matplotlib
    matplotlib.use('TkAgg')   # use a raster backend for plotting many points
    import matplotlib.pyplot as plt


class colors():    # Because I'm lazy
    BLACK = '\033[30m'
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'
    UNDERLINE = '\033[4m'
    RESET = '\033[0m'


def estimate_time_shift(x, y):
    """ Computes the cross-correlation between time series x and y, grabs the
        index of where it's a maximum.  This yields the time difference in
        samples between x and y.
    """
    if DEBUG: print("computing cross-correlation")
    corr = signal.correlate(y, x, mode='same', method='fft')
    if DEBUG: print("finished computing cross-correlation")

    nx, ny = len(x), len(y)
    t_samples = np.arange(nx)
    ct_samples = t_samples - nx//2  # try to center time shift (x axis) on zero
    cmax_ind = np.argmax(corr)      # where is the max of the cross-correlation?
    dt = ct_samples[cmax_ind]       # grab the time shift value corresponding to the max c-corr

    if DEBUG:
        print("cmax_ind, nx//2, ny//2, dt =",cmax_ind, nx//2, ny//2, dt)
        fig, (ax_x, ax_y, ax_corr) = plt.subplots(3, 1)
        ax_x.get_shared_x_axes().join(ax_x, ax_y)
        ax_x.plot(t_samples, x)
        ax_y.plot(t_samples, y)
        ax_corr.plot(ct_samples, corr)
        plt.show()
        
    return dt


#  for use in filtering filenames
def is_acceptable(filename):
    return filename.lower().endswith(('.wav', '.mp3', '.aif', '.aiff')) and \
        (('input_' in filename) or ('target_' in filename))



parser = argparse.ArgumentParser(description="Check dataset for mismatches",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('input_or_dir', help='input file 1, or directory')
parser.add_argument('target_or_more_files', nargs='*', help='target file 1, or optional more files (for non-directory usage)')
parser.add_argument('-a','--align', help='Fix: Align time (overwrites)', action='store_true')
parser.add_argument('-d','--delete', help='Fix: Delete extra/unmatching input or target files (overwrites)', action='store_true')
parser.add_argument('-f','--fast', help='Fast: skip timing checks', action='store_true')
parser.add_argument('-l','--length', help='Fix: Make lengths the same, by truncating (overwrites)', action='store_true')
parser.add_argument('-m','--mono', help='Fix: Force mono (overwrites)', action='store_true')
parser.add_argument('-s','--sr', help='Fix: Enforce sample rate of first input (overwrites)', action='store_true')
parser.add_argument('--fix', help='Fix: Apply all fixes (overwrites)', action='store_true')
args = parser.parse_args()
if (args.fix):
    [args.align, args.length, args.delete, args.sr, args.mono] = [True]*5
if DEBUG: print("args =",args)

# Make sense of how the user is specifying where to check
if args.target_or_more_files == []:
    dir = args.input_or_dir
    assert os.path.isdir(dir), f"{dir} is not a directory"

    print(f"Operating on directory {dir}")
    file_list, input_list, target_list = [], [], []
    # TODO: make sure it's actually a directory

    for dirpath, subdirs, files in os.walk(dir):
        for f in files:
            if f.lower().endswith(('.wav', '.mp3', '.aif', '.aiff')):
                if 'input' in f:
                    input_list.append(os.path.join(dirpath, f))
                elif 'target' in f:
                    target_list.append(os.path.join(dirpath, f))
            if is_acceptable(f):
                file_list.append(os.path.join(dirpath, f))
else:
    file_list = [args.input_or_dir] + args.target_or_more_files
    print(f"Operating on a list of {len(file_list)} files")
    # make a list of all the inputs, and a list of all the tagets
    input_list = list(filter(lambda x: 'input' in x, file_list))
    target_list = list(filter(lambda x: 'target' in x, file_list))

input_list.sort()
target_list.sort()


print("\n#### SIMPLE SANITY CHECKS based on filenames. Fast")

# sanity check: as many inputs as targets?
#   Note: one could imagine multiple targets for the same input, but we've
#   not done that for signaltrain.
ni, nt = len(input_list), len(target_list)

# TODO: make it tell us specifically what's lacking or extra
if ni != nt:
    print(f"{colors.RED}**PROBLEM**:{colors.RESET} {ni} inputs but {nt} targets")
    input_nums = [re.search('_[0-9]+_', os.path.basename(i)).group() for i in input_list]
    target_nums = [re.search('_[0-9]+_', os.path.basename(i)).group() for i in target_list]

    for i in input_nums:  # TODO: slow. make this faster with pythonic list operations
        if not (i in target_nums):
            print(f'  {i} is in inputs but not targets')
    for i in target_nums:
        if not (i in input_nums):
            print(f'  {i} is in targets but not inputs')
    sys.exit(1)

# total list of files
file_list = input_list + target_list

# Show what we'll be checking
if DEBUG: print("file_list = ",file_list)


# make sure same file doesn't exist in multiple directories
basenames = [os.path.basename(p) for p in file_list]   # grab all the filenames
assert len(basenames) == len(set(basenames)), "You've got duplicates"

# Loop through files
for i in range(ni):
    problem = False
    input_filename, target_filename = input_list[i], target_list[i]
    ibase, tbase = os.path.basename(input_filename), os.path.basename(target_filename)
    #print(f"input = {input_filename},    target = {target_filename}")

    # make sure the first is an input and the second is a target
    assert ('input_' in ibase) and ('target_' in tbase)

    # make sure the number-designation (first numbers found) of the files line up
    input_num = re.search('_[0-9]+_', ibase).group()
    target_num = re.search('_[0-9]+_', tbase).group()
    if input_num != target_num:
        print(f"{colors.RED}    **PROBLEM**:{colors.RESET} For input = {input_filename},  target = {target_filename}:")
        print(f"                 input_num ({input_num}) != target_num ({target_num})")
        sys.exit(1)
    # make sure they're in the same directory
    assert os.path.dirname(input_filename) == os.path.dirname(target_filename)



print("#### CHECKING THE AUDIO.  Slower.")
# Loop through files
for i in range(ni):
    problem = False
    input_filename, target_filename = input_list[i], target_list[i]
    ibase, tbase = os.path.basename(input_filename), os.path.basename(target_filename)
    print(f"input = {input_filename},    target = {target_filename}")


    repaired = False     # flag for if we want to output a fixed set of files

    # Read the audio files.  x, y = data for input, target
    x, sr_x = librosa.load(input_filename, sr=None, mono=False)
    y, sr_y = librosa.load(target_filename, sr=None, mono=False)

    # Check basic stuff
    if sr_x != sr_y:
        print(f"{colors.RED}    **PROBLEM**:sr_x ({sr_x}) != sr_y ({sr_y}){colors.RESET}")
        if args.sr:
            sr_y, repaired = sr_x, True
            print("     Fixing: setting sr_y := sr_x")
        else:
            problem = True

    if x.shape != y.shape:
        print(f"{colors.RED}    **PROBLEM**: x.shape ({x.shape}) != y.shape ({y.shape}){colors.RESET}")
        problem = True

    if args.mono:
        if len(x.shape) > 1:
            x = x[0,:]
            repaired = True
        if len(y.shape) > 1:
            y = y[0,:]
            repaired = True


    ### Check timing alignment.  Slow
    if not args.fast:
        #Compute the time delay (argmax of cross-correlation) in samples
        if DEBUG: print("    Calling estimate_time_shift (slow)")
        nx = len(x)
        short_len = nx//10       # 20-minute long audio files take a while, so use a subset
        dt = estimate_time_shift(x[0:short_len], y[0:short_len])
        if dt != 0:
            print(f"{colors.RED}    **PROBLEM**: Estimated time shift of {dt} samples from input to target.{colors.RESET}")
            problem = True

            if args.align:     # Fix the alignment
                print("        Trying to fix alignment...")
                if dt < 0:
                    x = x[-dt:]
                else:
                    y = y[dt:]
                newlen = min(x.shape[0], y.shape[0])
                x, y = x[0:newlen], y[0:newlen]

                # check to see how we did
                dt  = estimate_time_shift(x[0:short_len], y[0:short_len])
                print(f"        New estimated time shift = {dt} samples, x.shape = {x.shape}, y.shape = {y.shape}")
                if dt == 0:
                    problem, repaired = False, True
                else:
                    assert False, "Can't figure out what to do with this."

    if not problem:
        print(f" {colors.GREEN}  Looks good! :-) {colors.RESET}")

    if repaired: # save -- overwrite -- new versions of input & output
        print("       Overwriting new version of input and target...")
        wavfile.write(input_filename, sr_x, x)
        wavfile.write(target_filename, sr_y, y)