import matplotlib
matplotlib.use('Agg')
import tensorflow as tf
import numpy as np
import re, os
import random
import math
import json
import matplotlib.pyplot as plt
import sklearn.metrics
from seaborn import barplot, set_style
from sklearn.preprocessing import OneHotEncoder
from collections import OrderedDict


class OptionHandler():
    def __init__(self, config_dict):
        #self._usefp16 = config_dict['usefp16']
        self.config = config_dict
        self._name = config_dict['model_name']
        self._thresholdnum = config_dict['threshold_num']
        self._gpu = config_dict['gpu']
        self._allowsoftplacement = config_dict['allow_softplacement']
        self._numepochs = config_dict['num_epochs']
        self._numsteps = config_dict['num_steps']
        self._kmer2vec_embedding = config_dict['kmer2vec_embedding']
        self._kmer2vec_kmerdict = config_dict['kmer2vec_kmerdict']
        self._kmer2vec_kmercounts = config_dict['kmer2vec_kmercounts']
        self._kmersize = config_dict['kmer_size']
        self._nkmers = config_dict['n_kmers']
        self._embeddingdim = config_dict['embedding_dim']
        self._depth = config_dict['depth']
        self._structuredims = config_dict['structure_dims']
        self._traindata = config_dict['train_data']
        self._validdata = config_dict['valid_data']
        self._batchesdir = config_dict['batches_dir']
        self._inferencedata = config_dict['inference_data']
        self._inferencemode = config_dict['inference_mode'] == 'True'
        self._labels = config_dict['labels']
        self._nclasses = config_dict['n_classes']
        self._topk = config_dict['topk']
        self._classbalancing = config_dict['class_balancing'] == 'True'
        self._maxclassinbalance = config_dict['maxclass_inbalance']
        self._dropoutrate = config_dict['dropoutrate']
        self._learningrate = config_dict['learning_rate']
        self._epsilon = config_dict['epsilon']
        self._batchsize = config_dict['batch_size']
        self._batchgenmode = config_dict['batchgen_mode']
        self._windowlength = config_dict['window_length']
        self._minlength = config_dict['min_length']
        self._numthreads = config_dict['num_threads'] # TODO: validate this value
        self._restorepath = config_dict['restore_path']
        self._restore = config_dict['restore'] == 'True'
        self._debug = config_dict['debug'] == 'True'
        self._ecfile = config_dict['EC_file']
        self._summariesdir = config_dict['summaries_dir'] # for TensorBoard
        self._summariesdir += '_{l}_{n}_{w}_{g}_{b}_{lr}_{e}'.format(
            l=self._labels,
            n=self._nclasses,
            w=self._windowlength,
            g=self._batchgenmode,
            b=self._batchsize,
            lr=self._learningrate,
            e=self._epsilon)
        self._seqfile = config_dict['seqfile']
        self._survivalpop = config_dict['survival_pop']
        self._generations = config_dict['generations']
        self._systematic = config_dict['systematic']
        self._muts_per_gen = config_dict['muts_per_gen']
        self._decrease_muts_after_gen = config_dict['decrease_muts_after_gen']

        if not os.path.exists(self._summariesdir):
            os.makedirs(self._summariesdir)
        if not os.path.exists(self._batchesdir):
            os.makedirs(self._batchesdir)

    def write_dict(self):
        """
        Store the config_dict to disk in the summaries dir.
        :return:
        """
        with open(os.path.join(self._summariesdir, 'config_dict.JSON'), "w") as config_dict:
            json.dump(self.config, config_dict)
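

# Example usage (a sketch; the JSON must contain every key read in __init__):
#
#   with open('config_dict.JSON') as f:
#       opts = OptionHandler(json.load(f))  # also creates summaries/batches dirs
#   opts.write_dict()  # snapshot the config into the summaries dir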


class RocTracker():
    def __init__(self, optionhandler):
        self._opts = optionhandler
        self.metrics_path = os.path.join(self._opts._summariesdir, 'metrics')
        if not os.path.exists(self.metrics_path):
            os.mkdir(self.metrics_path)
        self.metrics_file = None # opened (and truncated) in calc_and_save()
        self.roc_score = []
        self.roc_labels = []
        self.pred_positives_sum = np.zeros(self._opts._nclasses)
        self.actual_positives_sum = np.zeros(self._opts._nclasses)
        self.true_positive_sum = np.zeros(self._opts._nclasses)
        self.num_calculations = 0

    def update(self, sigmoid_logits, true_labels):
        """
        Update the ROC tracker with the predictions made on one batch during validation.
        """
        # threshold the activations:
        # we consider a class "predicted" if its sigmoid activation is higher than 0.5
        batch_predicted_labels = np.greater(sigmoid_logits, 0.5)
        batch_predicted_labels = batch_predicted_labels.astype(float)


        batch_pred_pos = np.sum(batch_predicted_labels, axis=0) # sum up along the batch dim, keep the channels
        batch_actual_pos = np.sum(true_labels, axis=0) # sum up along the batch dim, keep the channels
        # calculate the true positives (prediction AND label positive), per class:
        batch_true_pos = np.sum(np.multiply(batch_predicted_labels, true_labels), axis=0)

        # and update the counts
        self.pred_positives_sum += batch_pred_pos #what the model said
        self.actual_positives_sum += batch_actual_pos #what the labels say
        self.true_positive_sum += batch_true_pos # where labels and model predictions>0.5 match

        assert len(self.true_positive_sum) == self._opts._nclasses

        # add the predictions to the roc_score tracker
        self.roc_score.append(sigmoid_logits)
        self.roc_labels.append(true_labels)
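
    # A tiny worked example of what update() accumulates (numbers made up):
    #   sigmoid_logits = [[0.9, 0.2], [0.6, 0.7]] -> predicted = [[1, 0], [1, 1]]
    #   true_labels    = [[1, 0], [0, 1]]
    #   pred_positives_sum  += [2, 1]    actual_positives_sum += [1, 1]
    #   true_positive_sum   += [1, 1]    (predicted * true, summed over the batch)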

    def calc_and_save(self, logfile):
        """
        Calculate the ROC curve with AUC value for the collected test values (roc_scores, roc_labels).
        Writes everything to files, plots curves and resets the Counters afterwards.
        """
        self.metrics_file = open(os.path.join(self.metrics_path, 'metrics.csv'), "w")

        self.num_calculations += 1

        # concat score and labels along the batchdim -> a giant test batch
        self.roc_score = np.concatenate(self.roc_score, axis=0)
        self.roc_labels = np.concatenate(self.roc_labels, axis=0)

        # get the total number of seqs we tested on:
        logfile.write('[*] Calculating metrics\n')
        test_set_size = self.roc_labels.shape[0]

        # do the calculations
        fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true=np.reshape(self.roc_labels, (-1)),
                                                         y_score=np.reshape(self.roc_score, (-1)))
        auc = sklearn.metrics.auc(fpr, tpr)

        precision_arr, recall_arr, thresholds = sklearn.metrics.precision_recall_curve(
            y_true=np.reshape(self.roc_labels, (-1)), probas_pred=np.reshape(self.roc_score, (-1))) # micro-average PR curve

        # now save the per-class score statistics to metrics.csv

        # get the max, min and mean scores for each class:
        # mask the scores with the true labels
        scores = self.roc_score * self.roc_labels

        mean_scores = np.mean(scores, axis=0)
        assert mean_scores.shape[0] == self._opts._nclasses
        max_scores = np.amax(scores, axis=0)
        assert max_scores.shape[0] == self._opts._nclasses
        min_scores = np.amin(scores, axis=0)
        assert min_scores.shape[0] == self._opts._nclasses

        self.metrics_file.write(str(mean_scores) + '\n')
        self.metrics_file.write(str(max_scores) + '\n')
        self.metrics_file.write(str(min_scores) + '\n')

        self.metrics_file.close()

        # get printable metrics (for log file)
        precision_class = self.true_positive_sum / np.maximum(1, self.pred_positives_sum) # where predPositives_sum == 0, tp_sum is also 0
        recall_class = self.true_positive_sum / np.maximum(1, self.actual_positives_sum) # where actualPositives_sum == 0, tp_sum is also 0
        precision = np.sum(self.true_positive_sum) / max(1, np.sum(self.pred_positives_sum))
        recall = np.sum(self.true_positive_sum) / max(1, np.sum(self.actual_positives_sum))
        f1 = 2 * precision * recall / max(1e-12, precision + recall)
        logfile.write("[*] Tested on %d seqs, "
                      "precision %.2f%%, "
                      "recall %.2f%%, "
                      "F1 %.2f%%\n" % (test_set_size, precision, recall, f1))
        logfile.flush()



        #plot ROC:
        plot_simple_curve(x=fpr, y=tpr, title=self._opts._name + '_ROC_curve',
                          legend=self._opts._name + ' (AUC = %0.4f)' % auc,
                          xname='False positive rate', yname='True positive rate',
                          filename=os.path.join(self.metrics_path, self._opts._name + '.roc_%d' % self.num_calculations))


        # PR curve
        plot_simple_curve(x=recall_arr, y=precision_arr,
                          title=self._opts._name + ' PR curve', legend=self._opts._name,
                          xname='Recall', yname='Precision',
                          filename=os.path.join(self.metrics_path, self._opts._name + '.precision_%d' % self.num_calculations))

        # reset the stats-collectors:
        self.roc_score = []
        self.roc_labels = []
        self.pred_positives_sum = np.zeros(self._opts._nclasses)
        self.actual_positives_sum = np.zeros(self._opts._nclasses)

        logfile.write('[*] Done testing.\n')


class StratifiedCounterDict(dict):
    def __missing__(self, key):
        self[key] = {'tp': 0,
                     'pred_p': 0,
                     }
        return self[key]
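
# Example: the __missing__ hook above auto-initializes per-stratum counters,
# so callers can increment without checking for the key first:
#
#   counts = StratifiedCounterDict()
#   counts['1.1.1.1']['tp'] += 1      # entry created on first access
#   counts['1.1.1.1']['pred_p'] += 1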


class BatchGenerator():
    def __init__(self, optionhandler, kmer2vec_embedding, kmer2id):
        self._opts = optionhandler
        self.mode = self._opts._batchgenmode # e.g. 'one_hot_padded', 'one_hot_padded_physchem', 'embed'
        self._kmer2vec_embedding = kmer2vec_embedding
        self._kmer2id = kmer2id
        self.inferencedata = open(self._opts._inferencedata, 'r')
        self.traindata = open(self._opts._traindata, 'r')
        self.validdata = open(self._opts._validdata, 'r')
        self.AA_to_id = {}
        self.id_to_AA = {}
        self.class_dict = OrderedDict()
        self.id_to_class = OrderedDict()
        self._get_class_dict()
        self.embedding_dict = OrderedDict()
        # determine the number of eval batches from the lines in validdata and the garbage percentage
        self.garbage_percentage = 0.2
        self.garbage_count = 0 # a counter for generated garbage sequences
        self.eval_batch_nr = int(_count_lines(self._opts._validdata) * (1 + self.garbage_percentage) //
                              self._opts._batchsize)
        print('Initialized BatchGenerator with batchsize: %d, num_eval_batches: %d, '
              'garbage_percentage: %f' % (self._opts._batchsize,
                                          self.eval_batch_nr,
                                          self.garbage_percentage))
        self.batches_per_file = 10000
        self.epochs = 2000
        self.curr_epoch = 0
        self.label_enc = OneHotEncoder(n_values=self._opts._nclasses, sparse=False)
        self.AA_enc = None # set below for the one-hot modes

        if self.mode.startswith('one_hot'):
            AAs = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                   'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',]

            self.AA_enc = OneHotEncoder(n_values=self._opts._depth, sparse=False)

            if 'physchem' in self.mode:
                _hydro = [1.8, 2.5, -3.5, -3.5, 2.8,
                          -0.4, -3.2, 4.5, -3.9, 3.8,
                          1.9, -3.5, -1.6, -3.5, -4.5,
                          -0.8, -0.7, 4.2, -0.9, -1.3]
                _molarweight = [89.094, 121.154, 133.104, 147.131, 165.192,
                                75.067, 155.156, 131.175, 146.189, 131.175,
                                149.208, 132.119, 115.132, 146.146, 174.203,
                                105.093, 119.119, 117.148, 204.228, 181.191]
                _is_polar = lambda aa: 1 if aa in 'DEHKNQRSTY' else 0
                _is_aromatic = lambda aa: 1 if aa in 'FWY' else 0
                _has_hydroxyl = lambda aa: 1 if aa in 'ST' else 0 # should we add TYR?
                _has_sulfur = lambda aa: 1 if aa in 'CM' else 0

                for i, aa in enumerate(AAs):
                    self.AA_to_id[aa]  = {'id': len(self.AA_to_id),
                                          'hydro': _hydro[i],
                                          'molweight': _molarweight[i],
                                          'pol': _is_polar(aa),
                                          'arom': _is_aromatic(aa),
                                          'sulf': _has_sulfur(aa),
                                          'OH': _has_hydroxyl(aa)}
            else:
                for aa in AAs:
                    self.AA_to_id[aa] = len(self.AA_to_id)
                # get the inverse:
                self.id_to_AA = {}
                for aa, id in self.AA_to_id.items():
                    self.id_to_AA[id] = aa
                self.id_to_AA[42] = '_'
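
        # Sketch of the resulting mapping (plain one-hot branch):
        #   self.AA_to_id == {'A': 0, 'C': 1, ..., 'Y': 19}
        #   self.AA_enc.fit_transform([[0]]) -> a 1 x depth one-hot row for 'A'
        #   id 42 additionally decodes to '_' (apparently a sentinel for
        #   non-residue positions)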

    def _get_class_dict(self):
        with open(self._opts._ecfile, "r") as ec_fobj:
            for line in ec_fobj:
                fields = line.strip().split()
                if fields[1].endswith('.csv'):  # TODO: delete this when the upstream error is fixed
                    fields[1] = fields[1][:-len('.csv')]

                if self._opts._labels == 'EC':
                    self.class_dict[fields[1]] = {'id': len(self.class_dict),
                                                  'size': int(fields[0]),
                                                  }
                if self._opts._labels == 'GO':
                    self.class_dict[fields[1].split('_')[1]] = {'id': len(self.class_dict),
                                                                'size': int(fields[0]),
                                                                }

        # get a reverse dict:
        for key in self.class_dict.keys():
            self.id_to_class[self.class_dict[key]['id']] = key

    def _update_embedding_dict(self, name, labels):
        """
        Update the embedding dict with new entries. This is used on the fly
        during inference/validation batch generation.
        """
        if len(self.embedding_dict) == 0:
            # add UNK token the first time this method is called
            self.embedding_dict['UNK'] = {}
            self.embedding_dict['UNK']['labels'] = ['UNK']
            self.embedding_dict['UNK']['id'] = 0 # we reserve 0 for the UNK token

        # check if the key (= name) is already in the dict:
        if name not in self.embedding_dict:
            assert len(self.embedding_dict) > 0
            self.embedding_dict[name] = {}
            self.embedding_dict[name]['labels'] = labels
            self.embedding_dict[name]['id'] = len(self.embedding_dict) - 1 # keep ids sequential (UNK == 0)
        else:
            print(name)
            print('WARNING: Duplicate name in embedding dict, kept the first entry. Check infile for redundant sequences!')

    def _csv_EC_decoder(self, in_csv, encoded_labels=True):
        line = in_csv.readline()
        fields = line.strip().split(';')
        name = fields[0]
        seq = fields[1]
        if fields[2].endswith('.csv'): #TODO assert this
            fields[2] = fields[2][:-len('.csv')]
        if self._opts._labels == 'EC':
            EC_str = fields[2] #TODO assert this
            if encoded_labels:
                EC_CLASS = 0 if self._opts._inferencemode else self.class_dict[EC_str]['id']
                label = [[EC_CLASS]] # we need a 2D array
            else:
                label = EC_str
        elif self._opts._labels == 'GO':
            GO_str = fields[2] #TODO assert this
            GOs = GO_str.split(',')
            if encoded_labels:
                # mirror the EC branch: a dummy label in inference mode
                label = [[0]] if self._opts._inferencemode else [[self.class_dict[go]['id']] for go in GOs] # 2D array
            else:
                label = GOs
        # TODO add an assertion for mode
        return name, seq, label
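
    # Expected CSV layout for the decoder above (assumed from the field usage;
    # fields are ';'-separated): name;sequence;label(s), e.g. a hypothetical
    # line "P12345;MKLVINLAV;1.1.1.1". For GO labels the third field may hold
    # several comma-separated terms.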

    def _seq2tensor(self, seq):
        """
        Encode a sequence as a one-hot matrix.

        Sequences longer than the window length are clipped, shorter ones are
        zero-padded, yielding a tensor of shape [depth, windowlength].
        """
        if self.mode == 'one_hot_padded':
            # first check if the sequence fits in the box:
            if len(seq) <= self._opts._windowlength:
                seq_matrix = np.ndarray(shape=(len(seq)), dtype=np.int32)
            # if sequence does not fit we clip it:
            else:
                seq_matrix = np.ndarray(shape=(self._opts._windowlength), dtype=np.int32)
            for i in range(len(seq_matrix)):
                seq_matrix[i] = self.AA_to_id[seq[i]]
            start_pos = 0 #because our sequence sits at the beginning of the box
            length = len(seq_matrix)  #true length (1 based)
            # now encode the sequence in one-hot
            oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq_matrix, (1, -1))), (len(seq_matrix), self._opts._depth))
            # pad the sequence to the boxsize:
            npad = ((0, self._opts._windowlength-length), (0, 0))
            padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
            padded_seq_matrix = np.transpose(padded_seq_matrix)
            del oh_seq_matrix, seq_matrix

            return padded_seq_matrix, start_pos, length #true length 1 based


        elif self.mode == 'one_hot_padded_physchem':
            # TODO: implement the physchem encoding
            pass

        else:
            raise ValueError("MODE must be one of ['one_hot_padded', 'one_hot_padded_physchem']")

    def _encode_single_seq(self, seq, desired_label=None):
        """
        Encode single sequence.
        """
        seq_matrix, start_pos, length = self._seq2tensor(seq)
        # look up the label in the class_dict:
        if desired_label:
            desired_label_ID = self.class_dict[desired_label]['id']

            # encode label one_hot:
            oh_label = self.label_enc.fit_transform([[desired_label_ID]]) # of shape [1, n_classes]
            return oh_label, seq_matrix, start_pos, length

        else:
            return seq_matrix

    def _process_csv(self, queue, return_name=True, encode_labels=True):
        """
        Read one CSV line and return the (optionally named) label and the encoded sequence tensor.
        """
        name, seq, label = self._csv_EC_decoder(queue, encoded_labels=encode_labels)
        seq_matrix, start_pos, end_pos = self._seq2tensor(seq)
        if return_name:
            return name, label, seq_matrix, start_pos, end_pos
        else:
            return label, seq_matrix, start_pos, end_pos

    def generate_garbage_sequence(self, return_name=False):
        """
        Generates a garbage sequence, i.e. an obviously non-functional sequence.
        :return:
        """
        modes = ['complete_random', 'pattern', 'same']
        mode = modes[random.randint(0, 2)]
        self.garbage_count += 1

        # get the length of the protein
        length = random.randint(175, self._opts._windowlength-10) #enforce padding

        if mode == 'pattern':
            #print('pattern')
            # Generate a repetitive pattern of 2-5 amino acids
            # get a random nr of AAs for the pattern:
            AA_nr = random.randint(2, 5)
            # get an index for each AA in AA_nr
            idxs = []
            for aa in range(AA_nr):
                idx_found = False
                while not idx_found:
                    aa_idx = random.randint(0, 19)
                    if not aa_idx in idxs:
                        idxs.append(aa_idx)
                        idx_found = True
            reps = math.ceil(length/AA_nr)
            seq = reps * idxs
            length = len(seq)

        elif mode == 'complete_random':
            # print('complete_random')
            seq = []
            for aa in range(length):
                # get an idx for every pos in length:
                idx = random.randint(0, 19)
                seq.append(idx)

        elif mode == 'same':
            # print('ONE')
            AA = random.randint(0, 19)
            seq = length * [AA]

        label = np.zeros([self._opts._nclasses])
        label = np.expand_dims(label, axis=0)
        garbage_label = np.asarray([1])
        garbage_label = np.expand_dims(garbage_label, axis=0)
        oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq, (1, -1))), (len(seq), self._opts._depth))
        # pad the sequence to the boxsize:
        npad = ((0, self._opts._windowlength-length), (0, 0))
        padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
        padded_seq = np.transpose(padded_seq_matrix)
        if return_name:
            # return a sequence ID to identify the generated sequence
            # generate a "random" name
            name = 'g%d' % self.garbage_count
            return name, padded_seq, label, garbage_label
        else:
            return padded_seq, label, garbage_label

    def generate_random_data_batch(self):
        seq_tensor_batch = tf.random_normal([self._opts._batchsize, self._opts._embeddingdim, self._opts._windowlength, 1])

        label_batch = [np.random.randint(0, self._opts._nclasses) for _ in range(self._opts._batchsize)]
        index_batch = [tf.constant(label) for label in label_batch]
        label_tensor = tf.stack(index_batch)
        onehot_labelled_batch = tf.one_hot(indices=tf.cast(label_tensor, tf.int32),
                                           depth=self._opts._nclasses)
        return seq_tensor_batch, onehot_labelled_batch

    def generate_single_seq_batch(self, seq, desired_label):
        """
        Generate a batch from a single sequence (i.e. the whole batch repeats the same seq). We need
        this because the network is built for a fixed batchsize; if we restore the model we have to keep that batchsize constant.
        """
        seq_tensors = []
        label_batch = []
        for _ in range(self._opts._batchsize):
            oh_label, seq_tensor, _, _ = self._encode_single_seq(seq, desired_label)
            label_batch.append(oh_label)
            seq_tensors.append(seq_tensor)
        batch_tensor = np.expand_dims(np.stack(seq_tensors, axis=0), axis=-1)
        label_tensor = np.concatenate(label_batch, axis=0)
        # drop the first dimension

        return batch_tensor, label_tensor
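
    # Resulting shapes (a sketch): batch_tensor is
    # [batchsize, depth, windowlength, 1], label_tensor is [batchsize, n_classes];
    # every row holds the same sequence and label.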

    def generate_inference_batch(self):
        """
        Generates a batch to infer labels for sequences. As everything is fed into the same graph,
        we use the same preprocessing and essentially the same logic as generate_batch, just reading
        from the inference file.
        :return: batch, labels, lengths (or positions, depending on mode)
        """
        seq_tensors = []
        label_batch = []
        positions = np.ndarray([self._opts._batchsize, 2])
        lengths = np.ndarray([self._opts._batchsize])
        in_csv = self.inferencedata
        for i in range(self._opts._batchsize):
            try:
                # note that this file is read sequentially, not shuffled!
                ECclass, seq_tensor, start_pos, end_pos = self._process_csv(in_csv, return_name=False,
                                                                            encode_labels=True)
            except IndexError: # the csv_decoder hit the end of the file
                # reopen the file:
                in_csv.close()
                # TODO: implement file shuffling when we reopen the file
                self.inferencedata = open(self._opts._inferencedata, 'r')
                in_csv = self.inferencedata
                # redo
                ECclass, seq_tensor, start_pos, end_pos = self._process_csv(in_csv, return_name=False,
                                                                            encode_labels=True)
            label_batch.append(ECclass)
            seq_tensors.append(seq_tensor)
            # record positions/lengths for every example, not only after a reopen:
            positions[i, 0] = start_pos
            positions[i, 1] = end_pos
            lengths[i] = end_pos

        batch = np.stack(seq_tensors, axis=0)

        if 'spp' in self.mode or 'padded' in self.mode:
            return batch, label_batch, lengths
        else:
            return batch, label_batch, positions

    def generate_batch(self, is_train):
        """
            Generate batches to train the model.
            As we use sparse softmax cross-entropy, we do NOT need to one-hot encode our labels.
        """
        seq_tensors = []
        label_batch = []
        positions = np.ndarray([self._opts._batchsize, 2])
        lengths = np.ndarray([self._opts._batchsize])
        if is_train:
            in_csv = self.traindata
        else:
            in_csv = self.validdata
        for i in range(self._opts._batchsize):
            try:
                # note that this file is read sequentially, not shuffled!
                ECclass, seq_tensor, start_pos, end_pos = self._process_csv(in_csv, return_name=False,
                                                                            encode_labels=True)
            except IndexError: # the csv_decoder hit the end of the file
                # reopen the file:
                in_csv.close()
                # TODO: implement file shuffling when we reopen the file
                if is_train:
                    self.traindata = open(self._opts._traindata, 'r')
                    in_csv = self.traindata
                else:
                    self.validdata = open(self._opts._validdata, 'r')
                    in_csv = self.validdata
                # redo
                ECclass, seq_tensor, start_pos, end_pos = self._process_csv(in_csv, return_name=False,
                                                                            encode_labels=True)
            label_batch.append(ECclass)
            seq_tensors.append(seq_tensor)
            # record positions/lengths for every example, not only after a reopen:
            positions[i, 0] = start_pos
            positions[i, 1] = end_pos
            lengths[i] = end_pos

        batch = np.stack(seq_tensors, axis=0)

        if 'spp' in self.mode or 'padded' in self.mode:
            return batch, label_batch, lengths
        else:
            return batch, label_batch, positions
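
    # Typical training-loop usage (a sketch; seq_ph/label_ph are hypothetical
    # placeholders of the model graph, not defined in this module):
    #   batch, labels, lengths = batchgen.generate_batch(is_train=True)
    #   sess.run(train_op, feed_dict={seq_ph: batch, label_ph: labels})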

    def generate_valid_batch(self, include_garbage=False):
        """
        Generates a validation batch, optionally mixing in generated garbage sequences.
        Uses the same preprocessing as generate_batch, reading from the validation file.
        :return: batch
        """
        seq_tensors = []
        in_csv = self.validdata
        if not include_garbage:
            for i in range(self._opts._batchsize):
                try:
                    """ Note that this is not shuffled! """
                    name, label, seq_tensor, _, _ = self._process_csv(in_csv, return_name=True,
                                                                      encode_labels=False)
                    self._update_embedding_dict(name, label)
                    seq_tensors.append(seq_tensor)
                except IndexError: # catches error from csv_decoder
                    # reopen the file:
                    in_csv.close()
                    # TODO: implement file shuffling when we reopen the file
                    self.validdata = open(self._opts._validdata, 'r')
                    in_csv = self.validdata
                    """ redo """
                    name, label, seq_tensor, _, _ = self._process_csv(in_csv, return_name=True,
                                                                      encode_labels=False)
                    self._update_embedding_dict(name, label)
                    seq_tensors.append(seq_tensor)

        else:
            num_garbage = math.ceil(self._opts._batchsize * self.garbage_percentage)
            for i in range(self._opts._batchsize - num_garbage):
                try:
                    """ Note that this is not shuffled! """
                    name, label, seq_tensor, _, _ = self._process_csv(in_csv, return_name=True,
                                                                      encode_labels=False)
                    self._update_embedding_dict(name, label)
                    seq_tensors.append(seq_tensor)
                except IndexError: # catches error from csv_decoder
                    # reopen the file:
                    in_csv.close()
                    # TODO: implement file shuffling when we reopen the file
                    self.validdata = open(self._opts._validdata, 'r')
                    in_csv = self.validdata
                    """ redo """
                    name, label, seq_tensor, _, _ = self._process_csv(in_csv, return_name=True,
                                                                      encode_labels=False)
                    self._update_embedding_dict(name, label)
                    seq_tensors.append(seq_tensor)

            for i in range(num_garbage):
                name, seq_tensor, _, _ = self.generate_garbage_sequence(return_name=True)
                label = 'garbage'
                self._update_embedding_dict(name, label)
                seq_tensors.append(seq_tensor)

        batch = np.stack(seq_tensors, axis=0)
        batch = np.expand_dims(batch, axis=-1)
        return batch
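
    # Garbage-mixing arithmetic (follows the code above): with batchsize=64 and
    # garbage_percentage=0.2, num_garbage = ceil(64 * 0.2) = 13, so the batch
    # holds 51 real validation sequences plus 13 generated garbage sequences.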


class TFrecords_generator():
    def __init__(self, optionhandler):
        self._opts = optionhandler
        self.label_enc = OneHotEncoder(n_values=self._opts._nclasses, sparse=False)
        self.AA_enc = None # set below for the one-hot modes
        self.mode = self._opts._batchgenmode # e.g. 'one_hot_padded', 'one_hot_padded_physchem', 'embed'
        self._kmer2vec_embedding = None # the embedding itself is restored from a checkpoint in embed_and_to_TFrecords()
        self._kmer2id = {}
        self.inferencedata = open(self._opts._inferencedata, 'r')
        self.traindata = open(self._opts._traindata, 'r')
        self.validdata = open(self._opts._validdata, 'r')
        self.AA_to_id = {}
        self.class_dict = {}
        self._get_class_dict()
        self.structure_dict = {}
        self.examples_per_file = 10000
        self.epochs = self._opts._numepochs
        self.curr_epoch = 0
        self.writer = None # TFRecordWriter, opened per output file

        # get the structure_dict
        structure_forms = ['UNORDERED', 'HELIX', 'STRAND', 'TURN']
        assert len(structure_forms) == self._opts._structuredims-1
        for s in structure_forms:
            self.structure_dict[s] = len(self.structure_dict) + 1 # reserve 0 for NO INFORMATION
        #self.structure_enc = OneHotEncoder(n_values=self._opts._structuredims, sparse=False)

        if self.mode.startswith('one_hot'):
            AAs = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                   'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'] # 'X' intentionally excluded

            self.AA_enc = OneHotEncoder(n_values=self._opts._depth, sparse=False)
            if 'physchem' in self.mode:
                _hydro = [1.8, 2.5, -3.5, -3.5, 2.8,
                          -0.4, -3.2, 4.5, -3.9, 3.8,
                          1.9, -3.5, -1.6, -3.5, -4.5,
                          -0.8, -0.7, 4.2, -0.9, -1.3]
                _molarweight = [89.094, 121.154, 133.104, 147.131, 165.192,
                                75.067, 155.156, 131.175, 146.189, 131.175,
                                149.208, 132.119, 115.132, 146.146, 174.203,
                                105.093, 119.119, 117.148, 204.228, 181.191]
                _is_polar = lambda aa: 1 if aa in 'DEHKNQRSTY' else 0
                _is_aromatic = lambda aa: 1 if aa in 'FWY' else 0
                _has_hydroxyl = lambda aa: 1 if aa in 'ST' else 0 # should we add TYR?
                _has_sulfur = lambda aa: 1 if aa in 'CM' else 0

                for i, aa in enumerate(AAs):
                    self.AA_to_id[aa]  = {'id': len(self.AA_to_id),
                                          'hydro': _hydro[i],
                                          'molweight': _molarweight[i],
                                          'pol': _is_polar(aa),
                                          'arom': _is_aromatic(aa),
                                          'sulf': _has_sulfur(aa),
                                          'OH': _has_hydroxyl(aa)}
            else:
                for aa in AAs:
                    self.AA_to_id[aa] = len(self.AA_to_id)
                # get the inverse:
                self.id_to_AA = {}
                for aa, id in self.AA_to_id.items():
                    self.id_to_AA[id] = aa
                self.id_to_AA[42] = '_'

        elif self.mode.startswith('embed'):
            with open(self._opts._kmer2vec_kmerdict, "r") as vocab_file:
                for line in vocab_file:
                    fields = line.strip().split()
                    fields[0] = fields[0].strip('\'b')
                    self._kmer2id[fields[0]] = len(self._kmer2id)
            #print(self._kmer2id)

    def _get_class_dict(self):
        with open(self._opts._ecfile, "r") as ec_fobj:
            for line in ec_fobj:
                fields = line.strip().split()
                if fields[1].endswith('.csv'):  # TODO: delete this when the upstream error is fixed
                    fields[1] = fields[1][:-len('.csv')]

                if self._opts._labels == 'EC':
                    self.class_dict[fields[1]] = {'id': len(self.class_dict),
                                                  'size': int(fields[0]),
                                                  }
                if self._opts._labels == 'GO':
                    self.class_dict[fields[1].split('_')[1]] = {'id': len(self.class_dict),
                                                                'size': int(fields[0]),
                                                                }

    def _csv_EC_decoder(self, in_csv):
        line = in_csv.readline()
        fields = line.strip().split(';')
        name = fields[0]
        seq = fields[1]
        if self._opts._labels == 'EC':
            if fields[3].endswith('.csv'):
                fields[3] = fields[3][:-len('.csv')]
            EC_str = fields[3]
            EC_CLASS = 0 if self._opts._inferencemode else self.class_dict[EC_str]['id']
            label = [[EC_CLASS]] # we need a 2D array
        elif self._opts._labels == 'GO':
            GO_str = fields[2]
            GOs = GO_str.split(',')
            if GOs[0].endswith('.csv'):
                GOs = [go[:-len('.csv')] if go.endswith('.csv') else go for go in GOs]
            # mirror the EC branch: a dummy label in inference mode
            label = [[0]] if self._opts._inferencemode else [[self.class_dict[go]['id']] for go in GOs] # 2D array
        # TODO add an assertion for mode
        structure_str = fields[3]
        return name, seq, label, structure_str

    def _seq2tensor(self, seq):
        """
        Encode a sequence as a one-hot (or embedded-id) matrix.

        Sequences longer than the window length are clipped, shorter ones are
        zero-padded, yielding a tensor of shape [depth, windowlength]
        (or an id vector of shape [windowlength] for the embed mode).
        """
        if self.mode == 'one_hot_padded':
            # first check if the sequence fits in the box:
            if len(seq) <= self._opts._windowlength:
                seq_matrix = np.ndarray(shape=(len(seq)), dtype=np.int32)
            # if sequence does not fit we clip it:
            else:
                seq_matrix = np.ndarray(shape=(self._opts._windowlength), dtype=np.int32)
            for i in range(len(seq_matrix)):
                seq_matrix[i] = self.AA_to_id[seq[i]]
            start_pos = 0 #because our sequence sits at the beginning of the box
            length = len(seq_matrix)  #true length (1 based)
            # now encode the sequence in one-hot
            oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq_matrix, (1, -1))), (len(seq_matrix), self._opts._depth))
            # pad the sequence to the boxsize:
            npad = ((0, self._opts._windowlength-length), (0, 0))
            padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
            padded_seq_matrix = np.transpose(padded_seq_matrix)
            del oh_seq_matrix, seq_matrix

            return padded_seq_matrix, start_pos, length #true length 1 based

        elif self.mode == 'one_hot_padded_physchem':
            # TODO: implement the physchem encoding
            pass

        elif self.mode.startswith('embed'):
            k = self._opts._kmersize
            # split the sequence into words
            frame_words = [seq[start:start + k]
                           for start in range(0, len(seq))]

            # determine the vector of IDs to lookup simultaneously in the embedding
            frame_ids = [self._kmer2id[w] for w in frame_words if len(w) == k]

            seq_matrix = np.zeros(shape=(self._opts._windowlength), dtype=np.int32)
            for i in range(len(seq_matrix)):
                try:
                    seq_matrix[i] = frame_ids[i]
                except IndexError: #means no more frames
                    pass


            start_pos = 0
            # pad the sequence to 1000
            length = len(frame_words) if len(frame_words) <= self._opts._windowlength else self._opts._windowlength
            # npad = ((0, self._opts._windowlength-length))
            # padded_seq_matrix = np.pad(seq_matrix, pad_width=npad, mode='constant', constant_values=0)
            padded_seq_matrix = seq_matrix
            del seq_matrix
            return padded_seq_matrix, start_pos, length

        else:
            raise ValueError("MODE must be one of ['one_hot_padded', 'one_hot_padded_physchem', 'embed']")

    def _get_structure(self, structure_str, seq_length):
        """
        Construct a One Hot Encoded Tensor with height = self._structure_dims, width = self._windowlength
        :param structure_str: str
        the entry in the swissprot csv corresponding to the FT fields in the swissprot textfile download
        Example format:
        [('TURN', '11', '14'), ('HELIX', '19', '27'), ('STRAND', '32', '36'), ('HELIX', '45', '54'),
        ('STRAND', '59', '69'), ('STRAND', '72', '80'), ('HELIX', '86', '96'), ('HELIX', '99', '112'),
        ('HELIX', '118', '123'), ('HELIX', '129', '131'), ('HELIX', '134', '143'), ('STRAND', '146', '149'),
        ('HELIX', '150', '156'), ('STRAND', '157', '159'), ('HELIX', '173', '182'), ('STRAND', '186', '189'),
        ('HELIX', '192', '194'), ('HELIX', '199', '211'), ('STRAND', '216', '221'), ('HELIX', '226', '239'),
        ('STRAND', '242', '246'), ('HELIX', '272', '275'), ('HELIX', '277', '279'), ('STRAND', '283', '285')]
        :return:
        """
        # if there is info about the structure:
        if structure_str != '[]':
            # get an array of len length:
            structure = np.ones([seq_length])
            # modify the structure str:
            # TODO: Improve the super ugly hack with a proper regex
            structure_str = re.sub('[\'\[\]\(]', '', structure_str)
            structure_features = [j.strip(', ').split(', ') for j in structure_str.strip(')').split(')')]

            for ft in structure_features:
                # get the ID for the ft:
                id_to_write = self.structure_dict[ft[0]]
                start = int(ft[1])
                end = int(ft[2])
                # clamp to the (possibly clipped) sequence length to avoid IndexError:
                for i in range(start, min(end + 1, seq_length)):
                    structure[i] = id_to_write
            # encode it One-Hot:

            #oh_structure_matrix = np.reshape(self.structure_enc.fit_transform(np.reshape(structure, [1, -1])),
                                             #[len(structure), self._opts._structuredims])
            # now pad it up to windowlength:
            # npad = ((0, self._opts._windowlength-seq_length), (0, 0))
            # padded_structure_matrix = np.pad(oh_structure_matrix, pad_width=npad, mode='constant', constant_values=0)
            # padded_structure_matrix = np.transpose(padded_structure_matrix)
            npad = ((0, self._opts._windowlength-seq_length))
            padded_structure_matrix = np.pad(structure, pad_width=npad, mode='constant', constant_values=0)
            #assert padded_structure_matrix.shape[1] == self._opts._windowlength

        else:
            # return only zeros if there is no information about the structure
            padded_structure_matrix = np.zeros([self._opts._windowlength])

        return padded_structure_matrix
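
    # Worked example (structure ids from __init__: UNORDERED=1, HELIX=2,
    # STRAND=3, TURN=4; 0 means no information):
    #   structure_str = "[('HELIX', '3', '5')]", seq_length = 8, windowlength = 10
    #   -> structure = [1, 1, 1, 2, 2, 2, 1, 1]   (HELIX written at 3..5)
    #   -> padded to [1, 1, 1, 2, 2, 2, 1, 1, 0, 0]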

    def _process_csv(self, queue):
        """
        Read one CSV line and return the one-hot labels, sequence matrix, structure tensor, start position and length.
        """
        _, seq, labels, structure_str = self._csv_EC_decoder(queue)
        seq_matrix, start_pos, length = self._seq2tensor(seq)
        structure_tensor = self._get_structure(structure_str, length)
        # encode the label one_hot:
        oh_label_tensor = self.label_enc.fit_transform(labels) # of shape [1, n_classes]
        classes = oh_label_tensor.shape[0]
        # open an array full of zeros to add the labels to
        oh_labels = np.zeros(self._opts._nclasses)
        for c in range(classes):
            oh_labels += oh_label_tensor[c]

        oh_labels = np.expand_dims(oh_labels, axis=0)

        return oh_labels, seq_matrix, structure_tensor, start_pos, length

    def generate_garbage_sequence(self):
        """
        Generates a garbage sequence, i.e. an obviously non-functional sequence.
        :return:
        """
        modes = ['complete_random', 'pattern', 'same']
        mode = modes[random.randint(0, 2)]

        # get the length of the protein
        length = random.randint(175, self._opts._windowlength-1)

        if mode == 'pattern':
            #print('pattern')
            # Generate a repetitive pattern of 2-5 amino acids
            # get a random nr of AAs for the pattern:
            AA_nr = random.randint(2, 5)
            # get an index for each AA in AA_nr
            idxs = []
            for aa in range(AA_nr):
                idx_found = False
                while not idx_found:
                    aa_idx = random.randint(0, 19)
                    if not aa_idx in idxs:
                        idxs.append(aa_idx)
                        idx_found = True
            reps = math.ceil(length/AA_nr)
            seq = reps * idxs
            length = len(seq)

        elif mode == 'complete_random':
            # print('complete_random')
            seq = []
            for aa in range(length):
                # get an idx for every pos in length:
                idx = random.randint(0, 19)
                seq.append(idx)

        elif mode == 'same':
            # print('ONE')
            AA = random.randint(0, 19)
            seq = length * [AA]

        label = np.zeros([self._opts._nclasses])
        label = np.expand_dims(label, axis=0)
        garbage_label = np.asarray([1])
        garbage_label = np.expand_dims(garbage_label, axis=0)
        oh_seq_matrix = np.reshape(self.AA_enc.fit_transform(np.reshape(seq, (1, -1))), (len(seq), self._opts._depth))
        # pad the sequence to the boxsize:
        npad = ((0, self._opts._windowlength-length), (0, 0))
        padded_seq_matrix = np.pad(oh_seq_matrix, pad_width=npad, mode='constant', constant_values=0)
        padded_seq = np.transpose(padded_seq_matrix)
        return padded_seq, label, garbage_label

    def _bytes_feature(self, value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _float_feature(self, value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    def _int64_feature(self, value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
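
    # The records written below can be read back with a parser along these
    # lines (a sketch; the feature keys match example_to_TFrecords with
    # structure=True, and np.ndarray.tostring() serializes float64 by default):
    #
    #   def _parse(serialized):
    #       feats = tf.parse_single_example(serialized, features={
    #           'seq_raw': tf.FixedLenFeature([], tf.string),
    #           'label_raw': tf.FixedLenFeature([], tf.string),
    #           'garbage_label_raw': tf.FixedLenFeature([], tf.string),
    #           'structure_label_raw': tf.FixedLenFeature([], tf.string),
    #       })
    #       seq = tf.decode_raw(feats['seq_raw'], tf.float64)
    #       label = tf.decode_raw(feats['label_raw'], tf.float64)
    #       return seq, label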

    def example_to_TFrecords(self, is_train, garbage_percentage=0.2, structure=True):
        include_garbage = garbage_percentage > 0

        # determine how many files we need to write:
        if is_train:
            length_data_set = _count_lines(self._opts._traindata)
            batch_files_name = os.path.basename(self._opts._traindata) + 'train_batch_{}'.format(str(self._opts._windowlength))
            print(batch_files_name)
            in_csv = self.traindata

        else:
            length_data_set = _count_lines(self._opts._validdata)
            batch_files_name = os.path.basename(self._opts._validdata) + 'valid_batch_{}'.format(str(self._opts._windowlength))
            print(batch_files_name)
            in_csv = self.validdata

        files_to_write = np.int32(np.ceil(length_data_set*(1+garbage_percentage)*2
                                          / float(self.examples_per_file))) # write everything twice

        for n in range(1, files_to_write+1):
            file_path = os.path.join(self._opts._batchesdir, batch_files_name) + '_' + str(n)
            self.writer = tf.python_io.TFRecordWriter(file_path)

            if structure:

                for i in range(self.examples_per_file):
                    if include_garbage and i % int(1/garbage_percentage) == 0:
                        # print("garbage_seq")
                        seq_tensor, label, garbage_label = self.generate_garbage_sequence()
                        structure_label = np.zeros([self._opts._windowlength])

                        assert seq_tensor.shape == (self._opts._depth, self._opts._windowlength), "%s" % str(seq_tensor.shape)
                        assert label.shape == (1, self._opts._nclasses)
                        # convert the features to a raw string:
                        seq_raw = seq_tensor.tostring()
                        label_raw = label.tostring()
                        garbage_label_raw = garbage_label.tostring()
                        structure_label_raw = structure_label.tostring()

                        example = tf.train.Example(
                            features=tf.train.Features(feature={
                                'windowlength': self._int64_feature(self._opts._windowlength),
                                'structure_depth': self._int64_feature(self._opts._structuredims),
                                'depth': self._int64_feature(self._opts._depth),
                                'label_classes': self._int64_feature(self._opts._nclasses),
                                'seq_raw': self._bytes_feature(seq_raw),
                                'label_raw': self._bytes_feature(label_raw),
                                'garbage_label_raw': self._bytes_feature(garbage_label_raw),
                                'structure_label_raw': self._bytes_feature(structure_label_raw),
                            }))
                        self.writer.write(example.SerializeToString())
                    else:
                        # print("validseq")
                        try:
                            oh_labels, seq_tensor, structure_label, _, _ = self._process_csv(in_csv)

                        except IndexError: # catches error from csv_decoder -> reopen the file:
                            in_csv.close()
                            if is_train:
                                self.traindata = open(self._opts._traindata, 'r')
                                in_csv = self.traindata
                            else:
                                self.validdata = open(self._opts._validdata, 'r')
                                in_csv = self.validdata
                            oh_labels, seq_tensor, structure_label, _, _ = self._process_csv(in_csv)

                        garbage_label = np.asarray([0]) # NOT garbage
                        garbage_label = np.expand_dims(garbage_label, axis=0)

                        assert seq_tensor.shape == (self._opts._depth, self._opts._windowlength)
                        assert oh_labels.shape == (1, self._opts._nclasses)
                        # convert the features to a raw string:
                        seq_raw = seq_tensor.tostring()
                        label_raw = oh_labels.tostring()
                        garbage_label_raw = garbage_label.tostring()
                        structure_label_raw = structure_label.tostring()

                        example = tf.train.Example(
                            features=tf.train.Features(feature={
                                'windowlength': self._int64_feature(self._opts._windowlength),
                                'structure_depth': self._int64_feature(self._opts._structuredims),
                                'depth': self._int64_feature(self._opts._depth),
                                'label_classes': self._int64_feature(self._opts._nclasses),
                                'seq_raw': self._bytes_feature(seq_raw),
                                'label_raw': self._bytes_feature(label_raw),
                                'garbage_label_raw': self._bytes_feature(garbage_label_raw),
                                'structure_label_raw': self._bytes_feature(structure_label_raw),
                            }))
                        self.writer.write(example.SerializeToString())

            elif not structure:

                for i in range(self.examples_per_file):
                    if include_garbage and i % int(1/garbage_percentage) == 0:
                        # print("garbage_seq")
                        seq_tensor, label, garbage_label = self.generate_garbage_sequence()

                        assert seq_tensor.shape == (self._opts._depth, self._opts._windowlength), "%s" % str(seq_tensor.shape)
                        assert label.shape == (1, self._opts._nclasses)
                        # convert the features to a raw string:
                        seq_raw = seq_tensor.tostring()
                        label_raw = label.tostring()

                        example = tf.train.Example(
                            features=tf.train.Features(feature={
                                'windowlength': self._int64_feature(self._opts._windowlength),
                                'depth': self._int64_feature(self._opts._depth),
                                'label_classes': self._int64_feature(self._opts._nclasses),
                                'seq_raw': self._bytes_feature(seq_raw),
                                'label_raw': self._bytes_feature(label_raw),
                            }))
                        self.writer.write(example.SerializeToString())
                    else:
                        # print("validseq")
                        try:
                            oh_labels, seq_tensor, _, _, _ = self._process_csv(in_csv)

                        except IndexError: # catches error from csv_decoder -> reopen the file:
                            in_csv.close()
                            if is_train:
                                self.traindata = open(self._opts._traindata, 'r')
                                in_csv = self.traindata
                            else:
                                self.validdata = open(self._opts._validdata, 'r')
                                in_csv = self.validdata
                            oh_labels, seq_tensor, _, _, _ = self._process_csv(in_csv)

                        assert seq_tensor.shape == (self._opts._depth, self._opts._windowlength)
                        assert oh_labels.shape == (1, self._opts._nclasses)
                        # convert the features to a raw string:
                        seq_raw = seq_tensor.tostring()
                        label_raw = oh_labels.tostring()

                        example = tf.train.Example(
                            features=tf.train.Features(feature={
                                'windowlength': self._int64_feature(self._opts._windowlength),
                                'depth': self._int64_feature(self._opts._depth),
                                'label_classes': self._int64_feature(self._opts._nclasses),
                                'seq_raw': self._bytes_feature(seq_raw),
                                'label_raw': self._bytes_feature(label_raw),
                            }))
                        self.writer.write(example.SerializeToString())
            self.writer.close()

    def embed_and_to_TFrecords(self, is_train):
        """
        Process the dataset into embedded sequences and write them as TFRecords.
        :return:
        """
        # ensure _batches_dir is correct:
        assert self._opts._batchesdir.endswith('embed/')

        # create batchesdir if not exists
        if not os.path.exists(self._opts._batchesdir):
            os.mkdir(self._opts._batchesdir)

        # collect the file for train/valid
        if is_train:
            length_data_set = _count_lines(self._opts._traindata)
            batch_files_name = os.path.basename(self._opts._traindata) + 'train_batch_{}'.format(str(self._opts._windowlength))
            print(batch_files_name)
            in_csv = self.traindata

        else:
            length_data_set = _count_lines(self._opts._validdata)
            batch_files_name = os.path.basename(self._opts._validdata) + 'valid_batch_{}'.format(str(self._opts._windowlength))
            print(batch_files_name)
            in_csv = self.validdata

        # "placeholder" for length:
        length_node = tf.placeholder(dtype=tf.int32)

        # load the embedding and construct a session to do the lookup!
        with tf.Session() as sess:
            with tf.variable_scope('kmer2vec') as vs:
                embedding = tf.get_variable('embedding',
                                                shape=[self._opts._nkmers,
                                                       self._opts._embeddingdim],
                                                trainable=False)
                #embedding_saver = tf.train.Saver({"w_out": embedding})
                embedding_saver = tf.train.Saver({"n_emb": embedding})
                embedding_saver.restore(sess, tf.train.latest_checkpoint(
                                                  self._opts._kmer2vec_embedding))

            with tf.variable_scope('process_sequence') as vs:
                sequence_node = tf.placeholder(tf.int32, shape=[self._opts._windowlength])

                # slice the sequence and extract the words to be embedded:
                #true_seq = sequence_node[:length_node]
                # look it up
                embedded_sequence = tf.transpose(tf.nn.embedding_lookup(embedding, sequence_node))
                embedded_sequence = tf.reshape(embedded_sequence, [self._opts._embeddingdim, self._opts._windowlength])

            # determine how many files we need to write:
            files_to_write = np.int32(np.ceil(length_data_set * 2 / float(self.examples_per_file))) # write every thing twice

            for n in range(1, files_to_write+1):
                file_path = os.path.join(self._opts._batchesdir, batch_files_name) + '_' + str(n)
                self.writer = tf.python_io.TFRecordWriter(file_path)

                for _ in range(self.examples_per_file):

                    try:
                        oh_labels, seq_tensor, _, _, length = self._process_csv(in_csv)

                    except IndexError: # catches error from csv_decoder -> reopen the file:
                        in_csv.close()
                        if is_train:
                            self.traindata = open(self._opts._traindata, 'r')
                            in_csv = self.traindata
                        else:
                            self.validdata = open(self._opts._validdata, 'r')
                            in_csv = self.validdata
                        oh_labels, seq_tensor, _, _, length = self._process_csv(in_csv)

                    feed_dict = {sequence_node: seq_tensor, length_node: length}

                    embedded_seq = sess.run(embedded_sequence, feed_dict=feed_dict) # returns an ndarray
                    embedded_seq = np.asarray(embedded_seq)

                    self._embed_to_TFrecords(embedded_seq, oh_labels)
            self.writer.close()

    def _embed_to_TFrecords(self, embedded_seq, oh_label):

        assert embedded_seq.shape == (self._opts._embeddingdim, self._opts._windowlength)
        assert oh_label.shape == (1, self._opts._nclasses)

        #print(np.argmax(oh_label, axis=1))

        # convert the features to a raw string:
        seq_raw = embedded_seq.tostring()
        label_raw = oh_label.tostring()

        example = tf.train.Example(
            features=tf.train.Features(feature={
                'windowlength': self._int64_feature(self._opts._windowlength),
                'depth': self._int64_feature(self._opts._embeddingdim),
                'label_classes': self._int64_feature(self._opts._nclasses),
                'seq_raw': self._bytes_feature(seq_raw),
                'label_raw': self._bytes_feature(label_raw),
            }))

        self.writer.write(example.SerializeToString())

    def produce_train_valid(self):
        if self.mode.startswith('one_hot'):
            #self.example_to_TFrecords(is_train=True, garbage_percentage=0, structure=False)
            self.example_to_TFrecords(is_train=False, garbage_percentage=0, structure=False)

        elif self.mode.startswith('embed'):
            self.embed_and_to_TFrecords(is_train=True)
            #self.embed_and_to_TFrecords(is_train=False)


def plot_histogram(log_file, save_dir):
    count_dict = {}
    with open(log_file, "r") as in_fobj:
        for line in in_fobj:
            pred_labels = line.strip().split()
            for label in pred_labels:
                try:
                    count_dict[label] += 1
                except KeyError:
                    count_dict[label] = 1
    bars = [count_dict[label] for label in count_dict.keys()]
    labels = [label for label in count_dict.keys()]
    set_style("whitegrid")
    fig, ax = plt.subplots()
    barplot(x=bars, y=labels, ax=ax)
    fig.savefig(os.path.join(save_dir, 'negative_test.png'))


def plot_simple_curve(x, y, title, legend, xname, yname, filename):
    plt.ioff()
    fig = plt.figure()
    plt.title(title)
    plt.plot(x, y, color="red", lw=2, label=legend)
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel(xname)
    plt.ylabel(yname)
    plt.legend(loc="lower right")
    plt.savefig(filename+".svg")
    plt.savefig(filename+".png")
    plt.close(fig)


def _count_lines(file_path):
    count = 0
    with open(file_path, "r") as fobj:
        for line in fobj:
            count += 1
    return count


def _add_var_summary(var, name, collection=None):
    """ attaches a lot of summaries to a given tensor"""
    with tf.name_scope(name):
        with tf.name_scope('summaries'):
            mean = tf.reduce_mean(var)
            tf.summary.scalar('mean', mean, collections=collection)
            with tf.name_scope('stddev'):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar('stddev', stddev, collections=collection)
            tf.summary.scalar('max', tf.reduce_max(var), collections=collection)
            tf.summary.scalar('min', tf.reduce_min(var), collections=collection)
            tf.summary.histogram('histogram', var, collections=collection)


def _variable_on_cpu(name, shape, initializer, trainable):
    """ Helper function to get a variable stored on cpu"""
    with tf.device('/cpu:0'): # pin the variable to CPU memory
        dtype = tf.float32
        var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype, trainable=trainable)
    #dtf.add_to_collection('CPU', var)
    return var


def softmax(X, theta=1.0, axis=None):
    """
    Compute the softmax of each element along an axis of X.

    Parameters
    ----------
    X: ND-Array. Probably should be floats.
    theta (optional): float parameter, used as a multiplier
        prior to exponentiation. Default = 1.0
    axis (optional): axis to compute values along. Default is the
        first non-singleton axis.

    Returns an array the same size as X. The result will sum to 1
    along the specified axis.
    """

    # make X at least 2d
    y = np.atleast_2d(X)

    # find axis
    if axis is None:
        axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1)

    # multiply y against the theta parameter,
    y = y * float(theta)

    # subtract the max for numerical stability
    y = y - np.expand_dims(np.max(y, axis = axis), axis)

    # exponentiate y
    y = np.exp(y)

    # take the sum along the specified axis
    ax_sum = np.expand_dims(np.sum(y, axis = axis), axis)

    # finally: divide elementwise
    p = y / ax_sum

    # flatten if X was 1D
    if len(X.shape) == 1: p = p.flatten()

    return p
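

if __name__ == '__main__':
    # Minimal self-check for the softmax helper above (made-up numbers).
    x = np.array([[2.0, 1.0, 0.1],
                  [0.0, 0.0, 0.0]])
    p = softmax(x, axis=1)
    assert np.allclose(np.sum(p, axis=1), 1.0) # rows sum to one
    print(p)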