import os
import matplotlib.pyplot as plt
import numpy as np
from keras.callbacks import LambdaCallback

# Epoch number most recently passed to the on_epoch_begin hook built by
# image_saver_callback. It is kept at module level so the nested hook
# functions can share it; as a consequence every callback created by this
# module reads and writes the same value.
current_epoch = 0

def _as_image_matrix(weights):
    """Return ``weights`` as a 2-D array that can be colormapped by imsave.

    Scalars and vectors become a single row (a vector of length N becomes
    shape (1, N)). Arrays with more than two dimensions have all leading
    axes flattened into rows while the last axis is kept as the columns.
    Two-dimensional input is returned with its shape unchanged.
    """
    matrix = np.atleast_2d(weights)
    if matrix.ndim > 2:
        matrix = matrix.reshape(-1, matrix.shape[-1])
    return matrix


def image_saver_callback(model, directory, epoch_interval=1, batch_interval=100, cmap='gray', render_videos=False):
    """Build a Keras ``LambdaCallback`` that saves layer weights as PNG images.

    At the start of every batch whose index is a multiple of
    ``batch_interval``, during epochs whose index is a multiple of
    ``epoch_interval``, each weight array of each layer in ``model`` is
    written to::

        <directory>/epoch_<EEE>-layer_<name>-weights_<NN>/<BBBBBBBBB>_<rows>x<cols>.png

    where ``EEE`` is the zero-padded epoch, ``NN`` the 1-based index of the
    array within the layer's weight list, and ``BBBBBBBBB`` the zero-padded
    batch index.

    Args:
        model: object exposing ``layers``; each layer must provide ``name``
            and ``get_weights()``.
        directory: root directory for the generated image folders.
        epoch_interval: save only in epochs divisible by this value.
        batch_interval: save only at batches divisible by this value.
        cmap: matplotlib colormap passed to ``plt.imsave``.
        render_videos: when true, run ``../bin/create_image_sequence.sh``
            (relative to this file) with ``directory`` as its argument once
            training ends.

    Returns:
        A ``LambdaCallback`` wired with the batch/epoch (and optionally
        train-end) hooks.

    Raises:
        ValueError: if ``epoch_interval`` or ``batch_interval`` is zero.
    """
    # Fail at construction time instead of with a ZeroDivisionError on the
    # first training batch.
    if not epoch_interval or not batch_interval:
        raise ValueError('epoch_interval and batch_interval must be non-zero')

    def save_image(weights, batch, layer_name, i):
        # `weights` is expected to be 2-D here (see _as_image_matrix).
        weight = str(i + 1).zfill(2)
        epoch = str(current_epoch).zfill(3)
        fold = os.path.join(directory, 'epoch_{}-layer_{}-weights_{}'.format(epoch, layer_name, weight))
        if not os.path.isdir(fold):
            os.makedirs(fold)
        name = os.path.join(fold,
                            '{}_{}x{}.png'.format(str(batch).zfill(9),
                                                  weights.shape[0], weights.shape[1]))
        plt.imsave(name, weights, cmap=cmap)

    def save_weight_images(batch, logs=None):
        if current_epoch % epoch_interval != 0 or batch % batch_interval != 0:
            return
        for layer in model.layers:
            # Fetch the weight list once per layer; an empty list simply
            # yields no iterations.
            for i, weights in enumerate(layer.get_weights()):
                save_image(_as_image_matrix(weights), batch, layer.name, i)

    def on_epoch_begin(epoch, logs=None):
        # Recorded in the module-level variable so the batch hook can read it.
        global current_epoch
        current_epoch = epoch

    def on_train_end(logs=None):
        # Local import: only needed when render_videos is enabled.
        import subprocess

        src = os.path.dirname(os.path.abspath(__file__))
        cmd = os.path.join(src, '..', 'bin', 'create_image_sequence.sh')
        # Pass an argument list (no shell) so paths containing spaces or
        # shell metacharacters reach the script intact.
        try:
            status = subprocess.call([cmd, directory])
        except OSError as err:
            # Keep this step best-effort: a missing or non-executable script
            # must not raise at the very end of training.
            print('could not run {}: {}'.format(cmd, err))
        else:
            print(status)

    kwargs = dict()
    kwargs['on_batch_begin'] = save_weight_images
    kwargs['on_epoch_begin'] = on_epoch_begin
    if render_videos:
        kwargs['on_train_end'] = on_train_end

    return LambdaCallback(**kwargs)