import matplotlib
import tensorflow as tf
import numpy as np
import re, os
import random
import math
import json
import matplotlib.pyplot as plt
import sklearn.metrics
from seaborn import barplot, set_style
from sklearn.preprocessing import OneHotEncoder
from collections import OrderedDict

class OptionHandler():
    """Hold every option read from the configuration dictionary.

    Each entry of ``config_dict`` is copied onto a private attribute so that
    other classes in this file can read it directly (e.g. ``opts._nclasses``).
    Creating an instance also creates the summaries and batches directories
    if they do not exist yet.
    """

    def __init__(self, config_dict):
        """Copy the options out of ``config_dict`` and prepare output directories.

        Args:
            config_dict: mapping holding every key read below; a missing key
                raises ``KeyError``. The flags ``inference_mode``,
                ``class_balancing``, ``restore`` and ``debug`` are compared
                against the string ``'True'``.
        """
        #self._usefp16 = config_dict['usefp16']
        self.config = config_dict
        self._name = config_dict['model_name']
        self._thresholdnum = config_dict['threshold_num']
        self._gpu = config_dict['gpu']
        self._allowsoftplacement = config_dict['allow_softplacement']
        self._numepochs = config_dict['num_epochs']
        self._numsteps = config_dict['num_steps']
        self._kmer2vec_embedding = config_dict['kmer2vec_embedding']
        self._kmer2vec_kmerdict = config_dict['kmer2vec_kmerdict']
        self._kmer2vec_kmercounts = config_dict['kmer2vec_kmercounts']
        self._kmersize = config_dict['kmer_size']
        self._nkmers = config_dict['n_kmers']
        self._embeddingdim = config_dict['embedding_dim']
        self._depth = config_dict['depth']
        self._structuredims = config_dict['structure_dims']
        self._traindata = config_dict['train_data']
        self._validdata = config_dict['valid_data']
        self._batchesdir = config_dict['batches_dir']
        self._inferencedata = config_dict['inference_data']
        self._inferencemode = config_dict['inference_mode'] == 'True'
        self._labels = config_dict['labels']
        self._nclasses = config_dict['n_classes']
        self._topk = config_dict['topk']
        self._classbalancing = config_dict['class_balancing'] == 'True'
        self._maxclassinbalance = config_dict['maxclass_inbalance']
        self._dropoutrate = config_dict['dropoutrate']
        self._learningrate = config_dict['learning_rate']
        self._epsilon = config_dict['epsilon']
        self._batchsize = config_dict['batch_size']
        self._batchgenmode = config_dict['batchgen_mode']
        self._windowlength = config_dict['window_length']
        self._minlength = config_dict['min_length']
        self._numthreads = config_dict['num_threads'] #TODO ASSERT THIS NUMBER!!!!!!!!
        self._restorepath = config_dict['restore_path']
        self._restore = config_dict['restore'] == 'True'
        self._debug = config_dict['debug'] == 'True'
        self._ecfile = config_dict['EC_file']
        # Summaries directory (for tensorboard). The run's key hyper-parameters
        # are appended so that different configurations do not share a directory.
        # NOTE(review): only ``g=`` was visible in the reviewed source; the
        # remaining keyword arguments are reconstructed from the placeholder
        # names -- confirm against version history.
        self._summariesdir = config_dict['summaries_dir'] + '_{l}_{n}_{w}_{g}_{b}_{lr}_{e}'.format(
            g=self._batchgenmode,
            w=self._windowlength,
            n=self._nclasses,
            b=self._batchsize,
            lr=self._learningrate,
            e=self._epsilon,
            l=self._labels)
        self._seqfile = config_dict['seqfile']
        self._survivalpop = config_dict['survival_pop']
        self._generations = config_dict['generations']
        self._systematic = config_dict['systematic']
        self._muts_per_gen = config_dict['muts_per_gen']
        self._decrease_muts_after_gen = config_dict['decrease_muts_after_gen']

        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self._summariesdir, exist_ok=True)
        os.makedirs(self._batchesdir, exist_ok=True)

    def write_dict(self):
        """Store the config_dict to disc as ``config_dict.JSON`` in the summaries directory."""
        with open(os.path.join(self._summariesdir, 'config_dict.JSON'), "w") as config_dict:
            json.dump(self.config, config_dict)

class RocTracker():
    """Collect validation predictions and derive ROC / precision-recall metrics.

    ``update`` is called once per validation batch and accumulates the raw
    scores, the labels and per-class positive counts. ``calc_and_save`` turns
    the accumulated data into curves and summary numbers, writes them out and
    resets the accumulators.
    """

    def __init__(self, optionhandler):
        """Create the ``metrics`` directory and the (empty) metrics file.

        Args:
            optionhandler: object exposing ``_summariesdir``, ``_nclasses``
                and ``_name``.
        """
        self._opts = optionhandler
        self.metrics_path = os.path.join(self._opts._summariesdir, 'metrics')
        os.makedirs(self.metrics_path, exist_ok=True)
        self.metrics_file = open(os.path.join(self.metrics_path, 'metrics.csv'), "w")
        self.roc_score = []
        self.roc_labels = []
        self.pred_positives_sum = np.zeros(self._opts._nclasses)
        self.actual_positives_sum = np.zeros(self._opts._nclasses)
        self.true_positive_sum = np.zeros(self._opts._nclasses)
        self.num_calculations = 0

    def update(self, sigmoid_logits, true_labels):
        """Update the ROC tracker with the predictions on one validation batch.

        Args:
            sigmoid_logits: array of sigmoid activations; axis 0 is summed
                away and the remainder must have ``n_classes`` entries.
            true_labels: array of 0/1 labels with the same layout.
        """
        # we consider a class "predicted" if its sigmoid activation is higher than 0.5
        batch_predicted_labels = np.greater(sigmoid_logits, 0.5).astype(float)

        # sum up along the batch dim, keep the channels
        batch_pred_pos = np.sum(batch_predicted_labels, axis=0)
        batch_actual_pos = np.sum(true_labels, axis=0)
        # True positives need the per-sample agreement of prediction and label;
        # multiplying the already summed per-class counts would over-count.
        batch_true_pos = np.sum(np.multiply(batch_predicted_labels, true_labels), axis=0)

        # and update the counts
        self.pred_positives_sum += batch_pred_pos # what the model said
        self.actual_positives_sum += batch_actual_pos # what the labels say
        self.true_positive_sum += batch_true_pos # where labels and model predictions>0.5 match

        assert len(self.true_positive_sum) == self._opts._nclasses

        # add the predictions to the roc_score tracker; calc_and_save
        # concatenates these lists along the batch dimension.
        self.roc_score.append(sigmoid_logits)
        self.roc_labels.append(true_labels)

    def calc_and_save(self, logfile):
        """Calculate ROC and PR curves for everything collected since the last call.

        Writes the mean / max / min label-masked scores per class to
        ``metrics.csv`` (overwriting it), logs micro-averaged precision,
        recall and F1 (in percent) to ``logfile``, plots both curves and
        resets the accumulators afterwards.

        Args:
            logfile: object with a ``write(str)`` method.

        Raises:
            ValueError: from ``np.concatenate`` when ``update`` has not been
                called since the last reset.
        """
        self.num_calculations += 1

        # concat score and labels along the batchdim -> a giant test batch
        self.roc_score = np.concatenate(self.roc_score, axis=0)
        self.roc_labels = np.concatenate(self.roc_labels, axis=0)

        # get the total number of seqs we tested on:
        logfile.write('[*] Calculating metrics\n')
        test_set_size = self.roc_labels.shape[0]

        # do the calculations on the flattened (micro-averaged) problem
        flat_labels = np.reshape(self.roc_labels, (-1))
        flat_scores = np.reshape(self.roc_score, (-1))
        fpr, tpr, _ = sklearn.metrics.roc_curve(y_true=flat_labels, y_score=flat_scores)
        auc = sklearn.metrics.auc(fpr, tpr)
        # Positional arguments: the keyword name of the score argument differs
        # between scikit-learn releases.
        precision_arr, recall_arr, _ = sklearn.metrics.precision_recall_curve(flat_labels, flat_scores)

        # get the mean, max and min scores for each class (scores masked by the labels)
        scores = self.roc_score * self.roc_labels
        mean_scores = np.mean(scores, axis=0)
        assert mean_scores.shape[0] == self._opts._nclasses
        max_scores = np.amax(scores, axis=0)
        assert max_scores.shape[0] == self._opts._nclasses
        min_scores = np.amin(scores, axis=0)
        assert min_scores.shape[0] == self._opts._nclasses

        # Release the handle left over from __init__ or an earlier call before
        # the file is truncated and rewritten; closing twice is harmless.
        self.metrics_file.close()
        self.metrics_file = open(os.path.join(self.metrics_path, 'metrics.csv'), "w")
        try:
            self.metrics_file.write(str(mean_scores) + '\n')
            self.metrics_file.write(str(max_scores) + '\n')
            self.metrics_file.write(str(min_scores) + '\n')
        finally:
            self.metrics_file.close()

        # get printable metrics (for log file); guard against empty denominators
        tp_total = np.sum(self.true_positive_sum)
        pred_total = np.sum(self.pred_positives_sum)
        actual_total = np.sum(self.actual_positives_sum)
        precision = tp_total / pred_total if pred_total > 0 else 0.0
        recall = tp_total / actual_total if actual_total > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
        # the format string prints a percent sign, so scale the fractions
        logfile.write("[*] Tested on %d seqs, "
                      "precision %.2f%%, "
                      "recall %.2f%%, "
                      "F1 %.2f%%\n" % (test_set_size, 100 * precision, 100 * recall, 100 * f1))

        #plot ROC:
        plot_simple_curve(x=fpr, y=tpr, title=self._opts._name + '_ROC_curve',
                          legend=self._opts._name + ' (AUC = %0.4f)' % auc,
                          xname='False positive rate', yname='True positive rate',
                          filename=os.path.join(self.metrics_path, self._opts._name + '.roc_%d' % self.num_calculations))

        # PR curve
        plot_simple_curve(x=recall_arr, y=precision_arr,
                          title=self._opts._name + ' PR curve', legend=self._opts._name,
                          xname='Recall', yname='Precision',
                          filename=os.path.join(self.metrics_path, self._opts._name + '.precision_%d' % self.num_calculations))

        # reset the stats-collectors (including the true positives, otherwise
        # the next evaluation would start with stale counts):
        self.roc_score = []
        self.roc_labels = []
        self.pred_positives_sum = np.zeros(self._opts._nclasses)
        self.actual_positives_sum = np.zeros(self._opts._nclasses)
        self.true_positive_sum = np.zeros(self._opts._nclasses)

        logfile.write('[*] Done testing.\n')

class StratifiedCounterDict(dict):
    """Dict that creates a fresh counter record the first time a key is read."""

    def __missing__(self, key):
        """Insert and return a zeroed counter record for ``key``.

        NOTE(review): the dict literal was unterminated in the reviewed
        source; only the two visible counters are restored -- confirm that no
        further keys are expected.
        """
        self[key] = {'tp': 0,
                     'pred_p': 0,
                     }
        return self[key]

class BatchGenerator():
    """Stream protein sequences from ';'-separated files and build model batches.

    Records are read line by line as ``name;sequence;labels`` from the train,
    validation and inference files named in the option handler. When a file
    is exhausted it is reopened and reading continues from its first line.

    NOTE(review): this class was syntactically incomplete in the reviewed
    source (missing ``try``/``else`` lines, truncated calls, empty loop
    bodies). Statements marked "reconstructed" restore the minimum needed for
    the visible code to run -- confirm them against version history.
    NOTE(review): ``OneHotEncoder(n_values=...)`` only exists in older
    scikit-learn releases; the calls are kept as they were.
    """

    # source name -> (attribute holding the open handle, option holding its path)
    _DATA_SOURCES = {'train': ('traindata', '_traindata'),
                     'valid': ('validdata', '_validdata'),
                     'inference': ('inferencedata', '_inferencedata')}

    def __init__(self, optionhandler, kmer2vec_embedding, kmer2id):
        """Open the data files and build the amino-acid lookup tables.

        Args:
            optionhandler: object exposing the ``_batchgenmode``, ``_traindata``,
                ``_validdata``, ``_inferencedata``, ``_batchsize``, ``_nclasses``
                and ``_depth`` options (further options are read by the methods).
            kmer2vec_embedding: stored on the instance, not used in this class.
            kmer2id: stored on the instance, not used in this class.
        """
        self._opts = optionhandler
        self.mode = self._opts._batchgenmode # one of ['window', 'bigbox', 'dynamic']
        self._kmer2vec_embedding = kmer2vec_embedding
        self._kmer2id = kmer2id
        # These handles stay open on purpose: the batch methods read from them incrementally.
        self.inferencedata = open(self._opts._inferencedata, 'r')
        self.traindata = open(self._opts._traindata, 'r')
        self.validdata = open(self._opts._validdata, 'r')
        self.AA_to_id = {}
        self.id_to_AA = {}
        # NOTE(review): nothing visible in __init__ calls _get_class_dict(), so
        # class_dict starts empty -- confirm who is expected to populate it.
        self.class_dict = OrderedDict()
        self.id_to_class = OrderedDict()
        self.embedding_dict = OrderedDict()
        # determine the number of batches for eval from lines in the validdata and the garbagepercentage
        self.garbage_percentage = 0.2
        self.garbage_count = 0 # a counter for generated garbage sequences
        # reconstructed: divisor and print arguments were cut off in the reviewed source
        self.eval_batch_nr = int(_count_lines(self._opts._validdata) * (1 + self.garbage_percentage) //
                                 self._opts._batchsize)
        print('Initialized Batchgen with   batchsize: %d,   numeval_batches: %d at'
              '                            garbage_percentage: %f' % (self._opts._batchsize,
                                                                      self.eval_batch_nr,
                                                                      self.garbage_percentage))
        self.batches_per_file = 10000
        self.epochs = 2000
        self.curr_epoch = 0
        self.label_enc = OneHotEncoder(n_values=self._opts._nclasses, sparse=False)
        self.AA_enc = 'where we put the encoder for the AAs'

        if self.mode.startswith('one_hot'):
            AAs = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                   'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',]

            self.AA_enc = OneHotEncoder(n_values=self._opts._depth, sparse=False)

            if 'physchem' in self.mode:
                _hydro = [1.8, 2.5, -3.5, -3.5, 2.8,
                          -0.4, -3.2, 4.5, -3.9, 3.8,
                          1.9, -3.5, -1.6, -3.5, -4.5,
                          -0.8, -0.7, 4.2, -0.9, -1.3]
                _molarweight = [89.094, 121.154, 133.104, 147.131, 165.192,
                                75.067, 155.156, 131.175, 146.189, 131.175,
                                149.208, 132.119, 115.132, 146.146, 174.203,
                                105.093, 119.119, 117.148, 204.228, 181.191]
                # Membership is tested against the string itself: wrapping it
                # in a list would compare a residue with the whole string and
                # never match.
                _polar = 'DEHKNQRSTY'
                _aromatic = 'FWY'
                _hydroxyl = 'ST' #should we add TYR??
                _sulfur = 'CM'

                for i, aa in enumerate(AAs):
                    self.AA_to_id[aa] = {'id': len(self.AA_to_id),
                                         'hydro': _hydro[i],
                                         'molweight': _molarweight[i],
                                         'pol': int(aa in _polar),
                                         'arom': int(aa in _aromatic),
                                         'sulf': int(aa in _sulfur),
                                         'OH': int(aa in _hydroxyl)}
            else:  # reconstructed branch
                for aa in AAs:
                    self.AA_to_id[aa] = len(self.AA_to_id)
                # get the inverse:
                self.id_to_AA = {aa_id: aa for aa, aa_id in self.AA_to_id.items()}
                self.id_to_AA[42] = '_'

    @staticmethod
    def _strip_csv_suffix(label):
        """Return ``label`` without a trailing '.csv' (suffix removal, not character stripping)."""
        return label[:-len('.csv')] if label.endswith('.csv') else label

    def _get_class_dict(self):
        """Fill ``class_dict`` and ``id_to_class`` from the class file.

        Each line of the file is ``<size> <label>``. In 'GO' mode the part of
        the label after the first underscore is used as key.
        """
        with open(self._opts._ecfile, "r") as ec_fobj:
            for line in ec_fobj:
                fields = line.strip().split()
                fields[1] = self._strip_csv_suffix(fields[1])  #TODO delete this when error is fixed

                if self._opts._labels == 'EC':
                    self.class_dict[fields[1]] = {'id': len(self.class_dict),
                                                  'size': int(fields[0]),
                                                  }
                if self._opts._labels == 'GO':
                    self.class_dict[fields[1].split('_')[1]] = {'id': len(self.class_dict),
                                                                'size': int(fields[0]),
                                                                }

        # get a reverse dict:
        for key, entry in self.class_dict.items():
            self.id_to_class[entry['id']] = key

    def _update_embedding_dict(self, name, labels):
        """Register ``name`` in the embedding dict; used on the fly during inference.

        The first call also inserts the 'UNK' entry with id 0. A name that is
        already present is left untouched and a warning is printed.

        NOTE(review): the id is taken after the entry was inserted, so the
        first real entry receives id 2 and id 1 stays unused; kept as found.
        """
        if len(self.embedding_dict) == 0:
            # add UNK token the first time this method is called
            self.embedding_dict['UNK'] = {}
            self.embedding_dict['UNK']['labels'] = ['UNK']
            self.embedding_dict['UNK']['id'] = 0 # we save 0 for the UNK token

        # check if the key (= name) is already in the dict:
        if name not in self.embedding_dict:
            assert len(self.embedding_dict) > 0
            self.embedding_dict[name] = {}
            self.embedding_dict[name]['labels'] = labels
            self.embedding_dict[name]['id'] = len(self.embedding_dict)
        else:  # reconstructed branch
            print('WARNING: Overwrote value in embedding dict. Check infile for redundant sequences!')

    def _csv_EC_decoder(self, in_csv, encoded_labels=True):
        """Read one ``name;seq;labels`` record from the open file ``in_csv``.

        Args:
            in_csv: object with a ``readline()`` method.
            encoded_labels: return class ids (nested lists) instead of the
                label strings.

        Returns:
            Tuple ``(name, seq, label)``.

        Raises:
            IndexError: when the line has fewer than three fields, which
                includes the empty line returned at end of file; the batch
                methods rely on this to detect an exhausted file.
            ValueError: when the configured label scheme is neither 'EC' nor 'GO'.
        """
        line = in_csv.readline()
        fields = line.strip().split(';')
        name = fields[0]
        seq = fields[1]
        fields[2] = self._strip_csv_suffix(fields[2]) #TODO assert this
        if self._opts._labels == 'EC':
            EC_str = fields[2] #TODO assert this
            if encoded_labels:
                EC_CLASS = 0 if self._opts._inferencemode else self.class_dict[EC_str]['id']
                label = [[EC_CLASS]] # we need a 2D array
            else:
                label = EC_str
        elif self._opts._labels == 'GO':
            GOs = 0 if self._opts._inferencemode else fields[2].split(',') #TODO assert this
            if not encoded_labels:
                label = GOs
            elif self._opts._inferencemode:
                # Iterating the inference placeholder 0 would raise TypeError.
                # NOTE(review): mirrors the placeholder of the EC branch -- confirm.
                label = [[0]]
            else:
                label = [[self.class_dict[go]['id']] for go in GOs] # returns a 2D array
        else:
            raise ValueError("labels must be one of ['EC', 'GO'], got %r" % (self._opts._labels,))
        return name, seq, label

    def _seq2tensor(self, seq):
        """One-hot encode ``seq`` and zero-pad it to the window length.

        Sequences longer than the window are clipped. Only the mode
        'one_hot_padded' is implemented.

        Returns:
            Tuple ``(matrix, start_pos, length)``: the transposed padded
            one-hot matrix (encoding depth x window length), the start
            position (always 0) and the number of encoded residues.

        Raises:
            KeyError: for a residue that is not in ``AA_to_id``.
            NotImplementedError: in mode 'one_hot_padded_physchem'.
            ValueError: for any other mode.
        """
        if self.mode == 'one_hot_padded':
            window = self._opts._windowlength
            # clip the sequence if it does not fit in the box
            seq_matrix = np.array([self.AA_to_id[aa] for aa in seq[:window]], dtype=np.int32)
            start_pos = 0 #because our sequence sits at the beginning of the box
            length = len(seq_matrix)  #true length (1 based)
            # now encode the sequence in one-hot; the row width is the encoder
            # depth rather than a hard-coded 20 so both stay in agreement
            oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq_matrix, (1, -1))),
                                       (length, self._opts._depth))
            # pad the sequence to the boxsize:
            npad = ((0, window - length), (0, 0))
            padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
            padded_seq_matrix = np.transpose(padded_seq_matrix)
            return padded_seq_matrix, start_pos, length #true length 1 based

        if self.mode == 'one_hot_padded_physchem':
            # Callers unpack the result, so returning nothing would only fail
            # later with an unrelated error.
            raise NotImplementedError("mode 'one_hot_padded_physchem' is not implemented")

        raise ValueError("Error: MODE must be of ['one_hot_padded', 'one_hot_padded_physchem']")

    def _encode_single_seq(self, seq, desired_label=None):
        """Encode a single sequence, optionally together with a one-hot label.

        Returns:
            ``(oh_label, seq_matrix, start_pos, length)`` when ``desired_label``
            is truthy, otherwise only ``seq_matrix``.
        """
        seq_matrix, start_pos, length = self._seq2tensor(seq)
        if not desired_label:
            return seq_matrix
        # look up the label in the class_dict:
        desired_label_ID = self.class_dict[desired_label]['id']
        # encode label one_hot:
        oh_label = self.label_enc.fit_transform([[desired_label_ID]]) # of shape [1, n_classes]
        return oh_label, seq_matrix, start_pos, length

    def _process_csv(self, queue, return_name=True, encode_labels=True):
        """Read the next record from ``queue`` and encode its sequence.

        Returns:
            ``(name, label, seq_matrix, start_pos, end_pos)`` or, with
            ``return_name=False``, the same tuple without the name.
        """
        name, seq, label = self._csv_EC_decoder(queue, encoded_labels=encode_labels)
        seq_matrix, start_pos, end_pos = self._seq2tensor(seq)
        if return_name:
            return name, label, seq_matrix, start_pos, end_pos
        return label, seq_matrix, start_pos, end_pos

    def _next_record(self, source, return_name, encode_labels):
        """Return the next processed record of ``source``, reopening its file at EOF.

        ``source`` is one of 'train', 'valid' or 'inference'. The exhausted
        handle is closed before the file is reopened so handles do not pile up.
        """
        handle_attr, path_attr = self._DATA_SOURCES[source]
        try:
            # Note that this is not shuffled!
            return self._process_csv(getattr(self, handle_attr), return_name=return_name,
                                     encode_labels=encode_labels)
        except IndexError: # catches error from csv_decoder
            # reopen the file:
            # TODO: implement file shuffling when we reopen the file
            getattr(self, handle_attr).close()
            setattr(self, handle_attr, open(getattr(self._opts, path_attr), 'r'))
            # redo
            return self._process_csv(getattr(self, handle_attr), return_name=return_name,
                                     encode_labels=encode_labels)

    def _collect_labelled_batch(self, source):
        """Read one batch of records from ``source``.

        Returns ``(batch, label_batch, lengths)`` when the mode contains 'spp'
        or 'padded', otherwise ``(batch, label_batch, positions)``.
        """
        batchsize = self._opts._batchsize
        seq_tensors = []
        label_batch = []
        # zero-initialised so no slot can ever hold uninitialised memory
        positions = np.zeros([batchsize, 2])
        lengths = np.zeros([batchsize])
        for i in range(batchsize):
            # reconstructed: encode_labels value and the two appends
            label, seq_tensor, start_pos, end_pos = self._next_record(source, return_name=False,
                                                                      encode_labels=True)
            label_batch.append(label)
            seq_tensors.append(seq_tensor)
            positions[i, 0] = start_pos
            positions[i, 1] = end_pos
            lengths[i] = end_pos

        batch = np.stack(seq_tensors, axis=0)

        if 'spp' in self.mode or 'padded' in self.mode:
            return batch, label_batch, lengths
        return batch, label_batch, positions

    def generate_garbage_sequence(self, return_name=False):
        """Generate a sequence full of garbage, i.e. an obviously non functional sequence.

        One of three schemes is picked at random: independent random residues,
        a repeated pattern of 2-5 distinct residues, or a single repeated residue.

        Returns:
            ``(padded_seq, label, garbage_label)`` with an all-zero label of
            shape (1, n_classes) and a garbage label of shape (1, 1);
            preceded by a generated name when ``return_name`` is true.
        """
        mode = random.choice(['complete_random', 'pattern', 'same'])
        self.garbage_count += 1

        # get the length of the protein
        length = random.randint(175, self._opts._windowlength-10) #enforce padding

        if mode == 'pattern':
            # repeat a pattern of 2 to 5 distinct amino acids
            AA_nr = random.randint(2, 5)
            idxs = random.sample(range(20), AA_nr)
            reps = math.ceil(length/AA_nr)
            seq = reps * idxs
            length = len(seq)
        elif mode == 'complete_random':
            # get an idx for every pos in length:
            seq = [random.randint(0, 19) for _ in range(length)]
        else:  # mode == 'same'
            seq = length * [random.randint(0, 19)]

        label = np.expand_dims(np.zeros([self._opts._nclasses]), axis=0)
        garbage_label = np.expand_dims(np.asarray([1]), axis=0)
        oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq, (1, -1))),
                                   (len(seq), self._opts._depth))
        # pad the sequence to the boxsize:
        npad = ((0, self._opts._windowlength-length), (0, 0))
        padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
        padded_seq = np.transpose(padded_seq_matrix)
        if return_name:
            # return a sequence ID to identify the generated sequence
            name = 'g%d' % self.garbage_count
            return name, padded_seq, label, garbage_label
        return padded_seq, label, garbage_label

    def generate_random_data_batch(self):
        """Return a batch of normally distributed noise with random one-hot labels (tensors)."""
        seq_tensor_batch = tf.random_normal([self._opts._batchsize, self._opts._embeddingdim, self._opts._windowlength, 1])

        label_batch = [np.random.randint(1,self._opts._nclasses) for _ in range(self._opts._batchsize)]
        index_batch = [tf.constant(label) for label in label_batch]
        label_tensor = tf.stack(index_batch)
        # reconstructed: the depth argument was cut off in the reviewed source
        onehot_labelled_batch = tf.one_hot(indices=tf.cast(label_tensor, tf.int32),
                                           depth=self._opts._nclasses)
        return seq_tensor_batch, onehot_labelled_batch

    def generate_single_seq_batch(self, seq, desired_label):
        """Generate a batch in which every entry is the same sequence.

        Needed because the network is built for a fixed batchsize, which has
        to be kept when a model is restored.

        Returns:
            ``(batch_tensor, label_tensor)``; the batch gets a trailing axis of size 1.
        """
        # The encoding is identical for every slot, so compute it once.
        oh_label, seq_tensor, _, _ = self._encode_single_seq(seq, desired_label)
        # reconstructed: the loop body that filled these lists was missing
        seq_tensors = [seq_tensor] * self._opts._batchsize
        label_batch = [oh_label] * self._opts._batchsize
        batch_tensor = np.expand_dims(np.stack(seq_tensors, axis=0), axis=-1)
        label_tensor = np.concatenate(label_batch, axis=0)

        return batch_tensor, label_tensor

    def generate_inference_batch(self):
        """Generate a batch from the inference file.

        Uses the same preprocessing as ``generate_batch`` on another file.

        :return: batch, labels, and lengths (modes containing 'spp' or
                 'padded') or positions (all other modes)
        """
        return self._collect_labelled_batch('inference')

    def generate_batch(self, is_train):
        """Generate a batch from the train (``is_train``) or validation file.

        As we use the sparse softmax ce, labels are returned as class ids and
        are NOT one-hot encoded.

        :return: batch, labels, and lengths (modes containing 'spp' or
                 'padded') or positions (all other modes)
        """
        return self._collect_labelled_batch('train' if is_train else 'valid')

    def generate_valid_batch(self, include_garbage=False):
        """Generate a batch from the validation file, optionally mixed with garbage.

        Every sequence is registered in the embedding dict under its name.
        With ``include_garbage`` the last ``ceil(batchsize * garbage_percentage)``
        slots are filled with generated garbage sequences labelled 'garbage'.

        :return: batch with a trailing axis of size 1
        """
        num_garbage = math.ceil(self._opts._batchsize * self.garbage_percentage) if include_garbage else 0
        seq_tensors = []
        for _ in range(self._opts._batchsize - num_garbage):
            # reconstructed: encode_labels value and the append
            # NOTE(review): string labels assumed here because the garbage
            # entries below use a string label too -- confirm.
            name, label, seq_tensor, _start, _end = self._next_record('valid', return_name=True,
                                                                      encode_labels=False)
            self._update_embedding_dict(name, label)
            seq_tensors.append(seq_tensor)

        for _ in range(num_garbage):
            name, seq_tensor, _label, _garbage_label = self.generate_garbage_sequence(return_name=True)
            self._update_embedding_dict(name, 'garbage')
            seq_tensors.append(seq_tensor)

        batch = np.stack(seq_tensors, axis=0)
        batch = np.expand_dims(batch, axis=-1)
        return batch

class TFrecords_generator():
    def __init__(self, optionhandler):
        """Prepare encoders, open the input files and build the lookup tables.

        Args:
          optionhandler: option object; the constructor reads `_nclasses`,
            `_batchgenmode`, `_inferencedata`, `_traindata`, `_validdata`,
            `_numepochs`, `_structuredims`, `_depth` and, for 'embed*' modes,
            `_kmer2vec_kmerdict` from it.

        Raises:
          AssertionError: if `_structuredims` is not the number of structure
            forms plus one (index 0 is reserved for "no information").
        """
        self._opts = optionhandler
        self.label_enc = OneHotEncoder(n_values=self._opts._nclasses, sparse=False)
        self.AA_enc = 'where we put the encoder for the AAs'
        # Modes tested for in this class: 'one_hot_padded',
        # 'one_hot_padded_physchem' and anything starting with 'embed'.
        self.mode = self._opts._batchgenmode
        self._kmer2vec_embedding = 'kmer2vec_embedding'
        self._kmer2id = {}
        # These handles are kept open on purpose: the TFRecords writers consume
        # them incrementally across several calls.
        self.inferencedata = open(self._opts._inferencedata, 'r')
        self.traindata = open(self._opts._traindata, 'r')
        self.validdata = open(self._opts._validdata, 'r')
        self.AA_to_id = {}
        self.class_dict = {}
        self.structure_dict = {}
        self.examples_per_file = 10000
        self.epochs = self._opts._numepochs
        self.curr_epoch = 0
        self.writer = 'where we put the writer'

        # get the structure_dict (ids start at 1, 0 is reserved for NO INFORMATION)
        structure_forms = ['UNORDERED', 'HELIX', 'STRAND', 'TURN']
        assert len(structure_forms) == self._opts._structuredims-1
        for form_id, form in enumerate(structure_forms, start=1):
            self.structure_dict[form] = form_id

        if self.mode.startswith('one_hot'):
            AAs = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                   'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',]

            self.AA_enc = OneHotEncoder(n_values=self._opts._depth, sparse=False)
            if 'physchem' in self.mode:
                _hydro = [1.8, 2.5, -3.5, -3.5, 2.8,
                          -0.4, -3.2, 4.5, -3.9, 3.8,
                          1.9, -3.5, -1.6, -3.5, -4.5,
                          -0.8, -0.7, 4.2, -0.9, -1.3]
                _molarweight = [89.094, 121.154, 133.104, 147.131, 165.192,
                                75.067, 155.156, 131.175, 146.189, 131.175,
                                149.208, 132.119, 115.132, 146.146, 174.203,
                                105.093, 119.119, 117.148, 204.228, 181.191]
                # Membership must be tested against the letters themselves;
                # `aa in ['DEHKNQRSTY']` compares with the whole string and is
                # therefore false for every single amino acid.
                polar = set('DEHKNQRSTY')
                aromatic = set('FWY')
                hydroxyl = set('ST')  # should we add TYR??
                sulfur = set('CM')

                for i, aa in enumerate(AAs):
                    self.AA_to_id[aa] = {'id': i,
                                         'hydro': _hydro[i],
                                         'molweight': _molarweight[i],
                                         'pol': int(aa in polar),
                                         'arom': int(aa in aromatic),
                                         'sulf': int(aa in sulfur),
                                         'OH': int(aa in hydroxyl)}
            else:
                # Plain integer ids. This must be the alternative to the
                # physchem table; running it unconditionally would overwrite
                # the property dicts built above.
                for i, aa in enumerate(AAs):
                    self.AA_to_id[aa] = i
                # get the inverse:
                self.id_to_AA = {aa_id: aa for aa, aa_id in self.AA_to_id.items()}
                self.id_to_AA[42] = '_'

        elif self.mode.startswith('embed'):
            with open(self._opts._kmer2vec_kmerdict, "r") as vocab_file:
                for line in vocab_file:
                    fields = line.strip().split()
                    if not fields:
                        # tolerate blank lines in the vocabulary file
                        continue
                    # Strip surrounding "'" and "b" characters from the token.
                    kmer = fields[0].strip('\'b')
                    self._kmer2id[kmer] = len(self._kmer2id)

    def _get_class_dict(self):
        with open(self._opts._ecfile, "r") as ec_fobj:
            for line in ec_fobj:
                fields = line.strip().split()
                if fields[1].endswith('.csv'):  #TODO delete this when error is fixed
                    fields[1] = fields[1].rstrip('.csv')

                if self._opts._labels == 'EC':
                    self.class_dict[fields[1]] = {'id': len(self.class_dict),
                                                  'size': int(fields[0]),
                if self._opts._labels == 'GO':
                    self.class_dict[fields[1].split('_')[1]] = {'id': len(self.class_dict),
                                                  'size': int(fields[0]),

    def _csv_EC_decoder(self, in_csv):
        line = in_csv.readline()
        fields = line.strip().split(';')
        name = fields[0]
        seq = fields[1]
        if self._opts._labels == 'EC':
            if fields[3].endswith('.csv'):
                fields[3] = fields[3].rstrip('.csv')
            EC_str = fields[3]
            EC_CLASS = 0 if self._opts._inferencemode else self.class_dict[EC_str]['id']
            label = [[EC_CLASS]] # we need a 2D array
        elif self._opts._labels == 'GO':
            GO_str = fields[2]
            GOs = 0 if self._opts._inferencemode else fields[2].split(',')
            if GOs[0].endswith('.csv'):
                GOs = [go.rstrip('.csv') for go in GOs]
            label = [[self.class_dict[go]['id']] for go in GOs] # returns a 2D array
        # TODO add an assertion for mode
        structure_str = fields[3]
        return name, seq, label, structure_str

    def _seq2tensor(self, seq):
        Does what you think it does.

        As for now we Fix the considered length to 200AA position,
        yielding a tensor of shape:
        [100, 196, 1]
        if self.mode == 'one_hot_padded':
            # first check if the sequence fits in the box:
            if len(seq) <= self._opts._windowlength:
                seq_matrix = np.ndarray(shape=(len(seq)), dtype=np.int32)
            # if sequence does not fit we clip it:
                seq_matrix = np.ndarray(shape=(self._opts._windowlength), dtype=np.int32)
            for i in range(len(seq_matrix)):
                seq_matrix[i] = self.AA_to_id[seq[i]]
            start_pos = 0 #because our sequence sits at the beginning of the box
            length = len(seq_matrix)  #true length (1 based)
            # now encode the sequence in one-hot
            oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq_matrix, (1, -1))), (len(seq_matrix), 20))
            # pad the sequence to the boxsize:
            npad = ((0, self._opts._windowlength-length), (0, 0))
            padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
            padded_seq_matrix = np.transpose(padded_seq_matrix)
            del oh_seq_matrix, seq_matrix

            # seq_matrix = np.ndarray(shape=(self._opts._windowlength), dtype=np.int32)
            # for i in range(len(seq_matrix)):
            #     seq_matrix[i] = self.AA_to_id[seq[i]]
            # start_pos = 0 #because our sequence sits at the beginning of the box
            # length = len(seq_matrix)  #true length (1 based)
            # now encode the sequence in one-hot
            # oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq_matrix, (1, -1))), (len(seq_matrix), self._opts._depth))
            # pad the sequence to the boxsize:
            # npad = ((0, self._opts._windowlength-length), (0, 0))
            # padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
            # padded_seq_matrix = np.transpose(padded_seq_matrix)
            # del oh_seq_matrix, seq_matrix

            return padded_seq_matrix, start_pos, length #true length 1 based

        elif self.mode == 'one_hot_padded_physchem':
            # TODO implement this shit

        elif self.mode.startswith('embed'):
            k = self._opts._kmersize
            # split the sequence into words
            frame_words = [seq[start:start + k]
                           for start in range(0, len(seq))]

            # determine the vector of IDs to lookup simultaneously in the embedding
            frame_ids = [self._kmer2id[w] for w in frame_words if len(w) == k]

            seq_matrix = np.zeros(shape=(self._opts._windowlength), dtype=np.int32)
            for i in range(len(seq_matrix)):
                    seq_matrix[i] = frame_ids[i]
                except IndexError: #means no more frames

            start_pos = 0
            # pad the sequence to 1000
            length = len(frame_words) if len(frame_words) <= self._opts._windowlength else self._opts._windowlength
            # npad = ((0, self._opts._windowlength-length))
            # padded_seq_matrix = np.pad(seq_matrix, pad_width=npad, mode='constant', constant_values=0)
            padded_seq_matrix = seq_matrix
            del seq_matrix
            return padded_seq_matrix, start_pos, length

            print("Error: MODE must be of ['one_hot_padded', 'one_hot_padded_physchem', 'embed']")

    def _get_structure(self, structure_str, seq_length):
        Construct a One Hot Encoded Tensor with height = self._structure_dims, width = self._windowlength
        :param structure_str: str
        the entry in the swissprot csv corresponding to the FT fields in the swissprot textfile download
        Example format:
        [('TURN', '11', '14'), ('HELIX', '19', '27'), ('STRAND', '32', '36'), ('HELIX', '45', '54'),
        ('STRAND', '59', '69'), ('STRAND', '72', '80'), ('HELIX', '86', '96'), ('HELIX', '99', '112'),
        ('HELIX', '118', '123'), ('HELIX', '129', '131'), ('HELIX', '134', '143'), ('STRAND', '146', '149'),
        ('HELIX', '150', '156'), ('STRAND', '157', '159'), ('HELIX', '173', '182'), ('STRAND', '186', '189'),
        ('HELIX', '192', '194'), ('HELIX', '199', '211'), ('STRAND', '216', '221'), ('HELIX', '226', '239'),
        ('STRAND', '242', '246'), ('HELIX', '272', '275'), ('HELIX', '277', '279'), ('STRAND', '283', '285')]
        # if there is info about the structure:
        if structure_str != '[]':
            # get an array of len length:
            structure = np.ones([seq_length])
            # modify the structure str:
            # TODO: Improve the super ugly hack with a proper regex
            structure_str = re.sub('[\'\[\]\(]', '', structure_str)
            structure_features = [j.strip(', ').split(', ') for j in structure_str.strip(')').split(')')]

            for ft in structure_features:
                # get the ID for the ft:
                id_to_write = self.structure_dict[ft[0]]
                start = int(ft[1])
                end = int(ft[2])
                for i in range(start, end+1):
                    structure[i] = id_to_write
            # encode it One-Hot:

            #oh_structure_matrix = np.reshape(self.structure_enc.fit_transform(np.reshape(structure, [1, -1])),
                                             #[len(structure), self._opts._structuredims])
            # now pad it up to windowlength:
            # npad = ((0, self._opts._windowlength-seq_length), (0, 0))
            # padded_structure_matrix = np.pad(oh_structure_matrix, pad_width=npad, mode='constant', constant_values=0)
            # padded_structure_matrix = np.transpose(padded_structure_matrix)
            npad = ((0, self._opts._windowlength-seq_length))
            padded_structure_matrix = np.pad(structure, pad_width=npad, mode='constant', constant_values=0)
            #assert padded_structure_matrix.shape[1] == self._opts._windowlength

            # return only zeros if there is no information about the structure
            padded_structure_matrix = np.zeros([self._opts._windowlength])

        return padded_structure_matrix

    def _process_csv(self, queue):
        """Read the next record from `queue` and turn it into tensors.

        Args:
          queue: open csv file handle, passed on to `_csv_EC_decoder`.

        Returns:
          A tuple `(oh_labels, seq_matrix, structure_tensor, start_pos, length)`
          where `oh_labels` has shape `[1, n_classes]` and is the sum of the
          one-hot rows of all labels of the record (multi-hot encoding).

        Raises:
          IndexError: propagated from `_csv_EC_decoder` when the file is
            exhausted.
        """
        _, seq, labels, structure_str = self._csv_EC_decoder(queue)
        seq_matrix, start_pos, length = self._seq2tensor(seq)
        structure_tensor = self._get_structure(structure_str, length)
        # encode the labels one_hot: one row per label, shape [n_labels, n_classes]
        oh_label_tensor = self.label_enc.fit_transform(labels)
        # collapse the rows into a single multi-hot vector of shape [1, n_classes]
        oh_labels = np.zeros(self._opts._nclasses) + np.sum(oh_label_tensor, axis=0)
        oh_labels = np.expand_dims(oh_labels, axis=0)

        return oh_labels, seq_matrix, structure_tensor, start_pos, length

    def generate_garbage_sequence(self):
        """Generate an obviously non-functional ("garbage") sequence.

        One of three generators is picked at random:
          * 'complete_random': every position is drawn independently,
          * 'pattern': a motif of 2-5 distinct amino acids is repeated,
          * 'same': one amino acid is repeated.
        The length is drawn from [175, windowlength - 1].

        Returns:
          A tuple `(padded_seq, label, garbage_label)`: the one-hot sequence
          transposed to `[depth, windowlength]`, an all-zero class label of
          shape `[1, n_classes]` and the garbage flag `[[1]]`.

        Raises:
          ValueError: from `random.randint` if `windowlength` is below 176.
        """
        n_amino_acids = 20
        window = self._opts._windowlength
        mode = random.choice(['complete_random', 'pattern', 'same'])

        # get the length of the protein
        length = random.randint(175, window - 1)

        if mode == 'pattern':
            # repeat a motif of 2 to 5 distinct amino acids
            motif = random.sample(range(n_amino_acids), random.randint(2, 5))
            reps = math.ceil(length / len(motif))
            # Truncate to `length`: the repeated motif may overshoot it and
            # would otherwise not fit into the window any more.
            seq = (reps * motif)[:length]

        elif mode == 'complete_random':
            seq = [random.randint(0, n_amino_acids - 1) for _ in range(length)]

        else:  # mode == 'same'
            seq = length * [random.randint(0, n_amino_acids - 1)]

        label = np.expand_dims(np.zeros([self._opts._nclasses]), axis=0)
        garbage_label = np.expand_dims(np.asarray([1]), axis=0)
        oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq, (1, -1))),
                                   (len(seq), self._opts._depth))
        # pad the sequence to the boxsize:
        npad = ((0, window - length), (0, 0))
        padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
        padded_seq = np.transpose(padded_seq_matrix)
        return padded_seq, label, garbage_label

    def _bytes_feature(self, value):
        """Wrap a single bytes value in a `tf.train.Feature` holding a one-element `BytesList`."""
        wrapped = tf.train.BytesList(value=[value])
        return tf.train.Feature(bytes_list=wrapped)

    def _float_feature(self, value):
        """Wrap a single float value in a `tf.train.Feature` holding a one-element `FloatList`."""
        # A FloatList belongs in the `float_list` field; passing it as
        # `bytes_list` is rejected by the protobuf constructor.
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    def _int64_feature(self, value):
        """Wrap a single integer value in a `tf.train.Feature` holding a one-element `Int64List`."""
        wrapped = tf.train.Int64List(value=[value])
        return tf.train.Feature(int64_list=wrapped)

    def example_to_TFrecords(self, is_train, garbage_percentage=0.2, structure=True):
        """Write one-hot encoded examples as TFRecords files into `self._opts._batchesdir`.

        The number of files is derived from the line count of the dataset,
        `garbage_percentage` and `self.examples_per_file`; when the csv is
        exhausted reading restarts from its beginning.

        Args:
          is_train: read the training csv if true, the validation csv otherwise.
          garbage_percentage: if non-zero, every `int(1/garbage_percentage)`-th
            example is a generated garbage sequence instead of a csv record.
          structure: if true, the structure depth, the garbage flag and the
            structure labels are stored with every example as well.

        Raises:
          AssertionError: if an encoded sequence or label has an unexpected shape.
        """
        include_garbage = garbage_percentage != 0
        window = self._opts._windowlength
        depth = self._opts._depth
        n_classes = self._opts._nclasses

        if is_train:
            data_path, name_tag, in_csv = self._opts._traindata, 'train_batch_{}', self.traindata
        else:
            data_path, name_tag, in_csv = self._opts._validdata, 'valid_batch_{}', self.validdata
        length_data_set = _count_lines(data_path)
        batch_files_name = os.path.basename(data_path) + name_tag.format(str(window))

        # determine how many files we need to write:
        files_to_write = np.int32(np.ceil(length_data_set*(1+garbage_percentage)*2
                                          / float(self.examples_per_file))) # write every thing twice
        garbage_every = int(1/garbage_percentage) if include_garbage else None

        for n in range(1, files_to_write+1):
            file_path = os.path.join(self._opts._batchesdir, batch_files_name) + '_' + str(n)
            self.writer = tf.python_io.TFRecordWriter(file_path)
            try:
                for i in range(self.examples_per_file):
                    if include_garbage and i % garbage_every == 0:
                        seq_tensor, label, garbage_label = self.generate_garbage_sequence()
                        structure_label = np.zeros([window])
                    else:
                        label, seq_tensor, structure_label = self._next_csv_example(in_csv)
                        garbage_label = np.expand_dims(np.asarray([0]), axis=0) # NOT garbage

                    assert seq_tensor.shape == (depth, window), "%s" % str(seq_tensor.shape)
                    assert label.shape == (1, n_classes)

                    # Store the arrays as raw bytes. `tobytes` yields the same
                    # bytes as the deprecated `tostring`.
                    feature = {
                        'windowlength': self._int64_feature(window),
                        'depth': self._int64_feature(depth),
                        'label_classes': self._int64_feature(n_classes),
                        'seq_raw': self._bytes_feature(seq_tensor.tobytes()),
                        'label_raw': self._bytes_feature(label.tobytes()),
                    }
                    if structure:
                        feature['structure_depth'] = self._int64_feature(self._opts._structuredims)
                        feature['garbage_label_raw'] = self._bytes_feature(garbage_label.tobytes())
                        feature['structure_label_raw'] = self._bytes_feature(structure_label.tobytes())

                    example = tf.train.Example(features=tf.train.Features(feature=feature))
                    self.writer.write(example.SerializeToString())
            finally:
                # flush and release the file even if an example failed
                self.writer.close()

    def _next_csv_example(self, in_csv):
        """Return `(oh_labels, seq_tensor, structure_label)` for the next record, restarting the file once at its end."""
        try:
            oh_labels, seq_tensor, structure_label, _, _ = self._process_csv(in_csv)
        except IndexError: # raised by the csv decoder at end of file
            # Rewind the handle instead of opening the file again, so that no
            # file descriptor is leaked on every pass over the data.
            in_csv.seek(0)
            oh_labels, seq_tensor, structure_label, _, _ = self._process_csv(in_csv)
        return oh_labels, seq_tensor, structure_label

    def embed_and_to_TFrecords(self, is_train):
        # NOTE(review): the line below reads like a docstring whose triple quotes are missing.
        Process the dataset to embedded sequences, and wirte the sequences as TF records.
        # ensure _batches_dir is correct:
        assert self._opts._batchesdir.endswith('embed/')

        # create batchesdir if not exists
        # NOTE(review): the body of this `if` is missing; presumably the directory is created here - confirm.
        if not os.path.exists(self._opts._batchesdir):

        # collect the file for train/valid
        if is_train:
            length_data_set = _count_lines(self._opts._traindata)
            batch_files_name = os.path.basename(self._opts._traindata) + 'train_batch_{}'.format(str(self._opts._windowlength))
            in_csv = self.traindata

            # NOTE(review): an `else:` appears to be missing here; as written the validation
            # settings below would overwrite the training ones.
            length_data_set = _count_lines(self._opts._validdata)
            batch_files_name = os.path.basename(self._opts._validdata) + 'valid_batch_{}'.format(str(self._opts._windowlength))
            in_csv = self.validdata

        # "placeholder" for length:
        length_node = tf.placeholder(dtype=tf.int32)

        # load the embedding and construct a session to do the lookup!
        with tf.Session() as sess:
            with tf.variable_scope('kmer2vec') as vs:
                # NOTE(review): the remaining arguments of `tf.get_variable` (shape etc.) are
                # missing from this statement - restore from the original source.
                embedding = tf.get_variable('embedding',
                #embedding_saver = tf.train.Saver({"w_out": embedding})
                # the saver maps the checkpoint entry "n_emb" onto the variable above
                embedding_saver = tf.train.Saver({"n_emb": embedding})
                # NOTE(review): the checkpoint directory argument is missing from this call.
                embedding_saver.restore(sess, tf.train.latest_checkpoint(

            with tf.variable_scope('process_sequence') as vs:
                # vector of `windowlength` int32 ids fed per example
                sequence_node = tf.placeholder(tf.int32, shape=[self._opts._windowlength])

                # slice the sequence and extract the words to be embedded:
                #true_seq = sequence_node[:length_node]
                # look it up
                embedded_sequence = tf.transpose(tf.nn.embedding_lookup(embedding, sequence_node))
                embedded_sequence = tf.reshape(embedded_sequence, [self._opts._embeddingdim, self._opts._windowlength])

            # determine how many files we need to write:
            files_to_write = np.int32(np.ceil(length_data_set * 2 / float(self.examples_per_file))) # write every thing twice

            for n in range(1, files_to_write+1):
                file_path = os.path.join(self._opts._batchesdir, batch_files_name) + '_' + str(n)
                # NOTE(review): the writer is never closed within this method.
                self.writer = tf.python_io.TFRecordWriter(file_path)

                for _ in range(self.examples_per_file):

                        # NOTE(review): a `try:` appears to be missing above this line. Also,
                        # `_process_csv` of this class returns five values while four are unpacked.
                        oh_labels, seq_tensor, _, length = self._process_csv(in_csv)

                    except IndexError: # catches error from csv_decoder -> reopen the file:
                        if is_train:
                            self.traindata = open(self._opts._traindata, 'r')
                            in_csv = self.traindata
                            # NOTE(review): an `else:` appears to be missing before the next two lines.
                            self.validdata = open(self._opts._validdata, 'r')
                            in_csv = self.validdata
                        oh_labels, seq_tensor, _, length = self._process_csv(in_csv)

                    feed_dict = {sequence_node: seq_tensor, length_node: length}

                    # NOTE(review): the call before `, feed_dict=` is missing; presumably the
                    # session evaluates `embedded_sequence` here - confirm.
                    embedded_seq =, feed_dict=feed_dict) # will return a list
                    embedded_seq = np.asarray(embedded_seq)

                    self._embed_to_TFrecords(embedded_seq, oh_labels)

    def _embed_to_TFrecords(self, embedded_seq, oh_label):
        """Serialize one embedded sequence with its label and write it to `self.writer`.

        Args:
          embedded_seq: array of shape `[embedding_dim, windowlength]`.
          oh_label: array of shape `[1, n_classes]`.

        Raises:
          AssertionError: if either array has a different shape.
        """
        window = self._opts._windowlength
        embedding_dim = self._opts._embeddingdim
        n_classes = self._opts._nclasses

        assert embedded_seq.shape == (embedding_dim, window)
        assert oh_label.shape == (1, n_classes)

        # Store the arrays as raw bytes. `tobytes` yields the same bytes as the
        # deprecated `tostring`.
        feature = {
            'windowlength': self._int64_feature(window),
            'depth': self._int64_feature(embedding_dim),
            'label_classes': self._int64_feature(n_classes),
            'seq_raw': self._bytes_feature(embedded_seq.tobytes()),
            'label_raw': self._bytes_feature(oh_label.tobytes()),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        self.writer.write(example.SerializeToString())


    def produce_train_valid(self):
        # Writes the TFRecords files for the configured batch generation mode.
        if self.mode.startswith('one_hot'):
            # Only the validation set is written here; the training call is commented out.
            #self.example_to_TFrecords(is_train=True, garbage_percentage=0, structure=False)
            self.example_to_TFrecords(is_train=False, garbage_percentage=0, structure=False)

        elif self.mode.startswith('embed'):
            # NOTE(review): the body of this branch is missing; presumably it calls
            # `embed_and_to_TFrecords` - confirm against the original source.

def plot_histogram(log_file, save_dir):
    """Plot how often each label occurs in `log_file` and save the bar chart.

    Args:
      log_file: path to a text file with whitespace-separated labels per line.
      save_dir: directory in which 'negative_test.png' is written.
    """
    count_dict = {}
    with open(log_file, "r") as in_fobj:
        for line in in_fobj:
            for label in line.strip().split():
                # The first occurrence counts as 1; initialising a new label
                # with 0 would under-count every label by one.
                count_dict[label] = count_dict.get(label, 0) + 1
    labels = list(count_dict)
    bars = [count_dict[label] for label in labels]
    fig, ax = plt.subplots()
    ax = barplot(x=bars, y=labels, ax=ax)
    fig.savefig(os.path.join(save_dir, 'negative_test.png'))
    # release the figure so that repeated calls do not accumulate open figures
    plt.close(fig)

def plot_simple_curve(x, y, title, legend, xname, yname, filename):
    """Plot `y` against `x` on the unit square with a diagonal reference line and save it.

    Args:
      x, y: coordinates of the curve.
      title: figure title.
      legend: legend label of the curve.
      xname, yname: axis labels.
      filename: path the figure is saved to.
    """
    fig = plt.figure()
    plt.title(title)
    plt.plot(x, y, color="red", lw=2, label=legend)
    # dashed diagonal from (0, 0) to (1, 1) as reference
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel(xname)
    plt.ylabel(yname)
    plt.legend(loc="lower right")
    # The figure was previously built and then dropped; persist it and free it.
    plt.savefig(filename)
    plt.close(fig)

def _count_lines(file_path):
    """Return the number of lines in the text file at `file_path`."""
    with open(file_path, "r") as handle:
        return sum(1 for _ in handle)

def _add_var_summary(var, name, collection=None):
    """Attach mean, stddev, max, min and histogram summaries to a tensor.

    All summaries live under the name scopes `name`/'summaries' and are added
    to `collection`.
    """
    with tf.name_scope(name), tf.name_scope('summaries'):
        var_mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', var_mean, collections=collection)
        with tf.name_scope('stddev'):
            squared_dev = tf.square(var - var_mean)
            var_stddev = tf.sqrt(tf.reduce_mean(squared_dev))
        tf.summary.scalar('stddev', var_stddev, collections=collection)
        var_max = tf.reduce_max(var)
        tf.summary.scalar('max', var_max, collections=collection)
        var_min = tf.reduce_min(var)
        tf.summary.scalar('min', var_min, collections=collection)
        tf.summary.histogram('histogram', var, collections=collection)

def _variable_on_cpu(name, shape, initializer, trainable):
    """Get or create a float32 variable under the '/cpu:0' device scope."""
    with tf.device('/cpu:0'): #TODO will this work?
        cpu_var = tf.get_variable(name, shape, initializer=initializer,
                                  dtype=tf.float32, trainable=trainable)
    return cpu_var

def softmax(X, theta = 1.0, axis = None):
    """Compute the softmax of each element along an axis of X.

    Args:
      X: ND-array or array-like. Probably should be floats.
      theta (optional): float parameter, used as a multiplier prior to
        exponentiation. Default = 1.0
      axis (optional): axis to compute values along. Default is the first
        non-singleton axis (of X promoted to at least 2D), or axis 0 if every
        axis has length 1.

    Returns:
      An array the same size as X. The result will sum to 1 along the
      specified axis.
    """
    # make X at least 2d
    y = np.atleast_2d(X)

    # find axis; fall back to 0 when there is no non-singleton axis instead of
    # letting next() raise StopIteration
    if axis is None:
        axis = next((dim for dim, size in enumerate(y.shape) if size > 1), 0)

    # multiply y against the theta parameter,
    y = y * float(theta)

    # subtract the max for numerical stability
    y = y - np.expand_dims(np.max(y, axis = axis), axis)

    # exponentiate y
    y = np.exp(y)

    # take the sum along the specified axis
    ax_sum = np.expand_dims(np.sum(y, axis = axis), axis)

    # finally: divide elementwise
    p = y / ax_sum

    # flatten if X was 1D; np.ndim also works for plain lists, which have no .shape
    if np.ndim(X) == 1:
        p = p.flatten()

    return p