import sys, os, glob, json

from typing import Optional, List

import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from source.utils import audio_to_frames

def get_model_name(path:str):
    """Resolve *path* to a concrete checkpoint prefix.

    If *path* is a directory, the checkpoint that TensorFlow's checkpoint
    state for that directory points to (``model_checkpoint_path``) is
    returned. Any other path is returned unchanged.

    Raises:
        ValueError: if *path* is a directory with no checkpoint state.
    """
    if os.path.isdir(path):
        state = tf.train.get_checkpoint_state(path)
        # get_checkpoint_state returns None when the directory holds no
        # checkpoint; fail with a clear message instead of an AttributeError.
        if state is None or not state.model_checkpoint_path:
            raise ValueError('No checkpoint found in directory: {}'.format(path))
        path = state.model_checkpoint_path
    return path


def print_checkpoint(path:str, name:Optional[str]=None):
    """Print variables stored in a checkpoint to stdout.

    Args:
        path: checkpoint prefix, or a directory resolved via ``get_model_name``.
        name: name of a single tensor to print. If falsy (``None`` or ``''``),
            every variable in the checkpoint is printed in sorted name order
            as a ``name:shape`` line followed by its value.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(get_model_name(path))

    if name:
        print(name)
        print(reader.get_tensor(name))
        return

    for key in sorted(reader.get_variable_to_shape_map()):
        tensor = reader.get_tensor(key)
        print('{}:{}'.format(key, tensor.shape))
        print(tensor)


def print_graph(sess=None):
    """Print the name, shape and value of every global variable in the graph.

    Args:
        sess: optional ``tf.Session`` used to read the variables. When ``None``
            (the default), ``Variable.eval`` falls back to the default session,
            so the call must happen inside a default-session context such as
            ``with tf.Session() as sess:``.
    """
    for variable in tf.global_variables():
        # Evaluate once per variable: every eval() is a separate session run.
        value = variable.eval(session=sess)
        print('{}:{}'.format(variable.name, value.shape))
        print(value)


def get_var_from_checkpoint(path:str, tensor_name:str):
    """Read a single tensor's value from a checkpoint.

    Args:
        path: checkpoint prefix, or a directory resolved via ``get_model_name``.
        tensor_name: name of the tensor as stored in the checkpoint.

    Returns:
        The value returned by the checkpoint reader's ``get_tensor``.
    """
    checkpoint_prefix = get_model_name(path)
    checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_prefix)
    value = checkpoint_reader.get_tensor(tensor_name)
    return value


def get_var_from_graph(name:str):
    """Return the first global variable whose ``name`` equals *name*, else ``None``.

    The comparison is an exact string match against ``Variable.name`` as
    reported by ``tf.global_variables()``.
    """
    matching = (candidate for candidate in tf.global_variables()
                if candidate.name == name)
    return next(matching, None)


def update_var_in_graph(sess:tf.Session, name:str, value:np.ndarray):
    """Assign *value* to the global variable called *name*, running in *sess*.

    Raises:
        ValueError: if no global variable with that name exists in the graph.
    """
    var = get_var_from_graph(name)
    if var is None:
        # Without this check the failure surfaces as "'NoneType' object has
        # no attribute 'assign'", which hides the offending name.
        raise ValueError('Variable not found in graph: {}'.format(name))
    # NOTE: var.assign() adds a new assign op to the graph on every call.
    sess.run(var.assign(value))


def update_var_from_checkpoint(sess:tf.Session, name_to:str, name_from:str, path:str):
    """Copy a tensor from a checkpoint into a global variable of the current graph.

    Args:
        sess: session in which the assignment is run.
        name_to: name of the destination global variable in the graph.
        name_from: name of the source tensor in the checkpoint.
        path: checkpoint prefix, or a directory containing checkpoints.
    """
    # get_var_from_checkpoint resolves directory paths itself, so the path is
    # passed through unchanged rather than being resolved twice.
    value = get_var_from_checkpoint(path, name_from)
    update_var_in_graph(sess, name_to, value)
    
      
def predict_from_checkpoint(audio:np.ndarray, checkpoint_dir:str, additional_layer_names=None, n_batch=1) -> Optional[List]:
    """Run a saved model over *audio* and collect the outputs of selected layers.

    The graph is rebuilt from the ``.meta`` file of the latest checkpoint in
    *checkpoint_dir* and the weights are restored from it. Tensor/op names are
    looked up in ``vocab.json`` in the same directory (keys ``x``, ``y``,
    ``init``, ``logits``, ``n_shuffle``, ``n_repeat``, ``n_batch``). The audio
    is split into frames with ``audio_to_frames``, fed to the graph through the
    ``init`` op, and the requested layers are fetched until the pipeline
    raises ``OutOfRangeError``.

    Args:
        audio: audio passed to ``audio_to_frames``.
        checkpoint_dir: directory holding the checkpoint files and ``vocab.json``.
        additional_layer_names: optional iterable of tensor names to fetch in
            addition to ``logits``.
        n_batch: batch size fed to the ``n_batch`` placeholder; values <= 0
            mean "all frames in a single batch".

    Returns:
        A list of NumPy arrays, one per fetched layer (``logits`` first, then
        *additional_layer_names* in order), each being all batches concatenated
        along axis 0. ``None`` if *checkpoint_dir* contains no checkpoint.
    """
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    if not checkpoint_path:
        return None

    graph = tf.Graph()
    with graph.as_default():

        saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
        with open(os.path.join(checkpoint_dir, 'vocab.json'), 'r', encoding='utf-8') as fp:
            vocab = json.load(fp)

        x = graph.get_tensor_by_name(vocab['x'])
        y = graph.get_tensor_by_name(vocab['y'])
        init = graph.get_operation_by_name(vocab['init'])
        logits = graph.get_tensor_by_name(vocab['logits'])
        ph_n_shuffle = graph.get_tensor_by_name(vocab['n_shuffle'])
        ph_n_repeat = graph.get_tensor_by_name(vocab['n_repeat'])
        ph_n_batch = graph.get_tensor_by_name(vocab['n_batch'])

        layers = [logits]
        if additional_layer_names:
            layers.extend(graph.get_tensor_by_name(layer_name)
                          for layer_name in additional_layer_names)

        frames = audio_to_frames(audio, x.shape[1], None)
        # y is fed zeros because no ground truth is available here
        # (assumes the fetched layers do not depend on y -- confirm).
        labels = np.zeros((frames.shape[0],), dtype=np.int32)

        # Per-layer lists of batch outputs, seeded with an empty array of the
        # layer's trailing shape so a run with no batches still returns
        # correctly shaped arrays. Concatenating once at the end avoids
        # re-copying the accumulated result on every batch.
        chunks = [[np.empty([0] + layer.shape[1:].as_list(), dtype=np.float32)]
                  for layer in layers]

        with tf.Session() as sess:

            saver.restore(sess, checkpoint_path)
            sess.run(init, feed_dict={
                x: frames,
                y: labels,
                ph_n_shuffle: 1,
                ph_n_repeat: 1,
                ph_n_batch: n_batch if n_batch > 0 else frames.shape[0],
            })

            while True:
                try:
                    outputs = sess.run(layers)
                except tf.errors.OutOfRangeError:
                    # The input pipeline is exhausted.
                    break
                for layer_chunks, output in zip(chunks, outputs):
                    layer_chunks.append(output)

    return [np.concatenate(layer_chunks) for layer_chunks in chunks]



if __name__ == '__main__':

    # Usage: script.py [checkpoint_path_or_dir] [tensor_name]
    # The default location is built with os.path.join so it works on any OS.
    path = sys.argv[1] if len(sys.argv) > 1 else os.path.join('..', 'test', 'ckpt')
    tensor_name = sys.argv[2] if len(sys.argv) > 2 else None
    print_checkpoint(path, tensor_name)