# -*- coding: utf-8 -*-
"""
@author: Yan Shao, yan.shao@lingfil.uu.se
"""
import codecs
import sys
import numpy as np
import random
import os
import math
from reader import get_gold

reload(sys)  # Python 2: re-expose sys.setdefaultencoding, hidden at interpreter start-up
sys.setdefaultencoding('utf-8')

punc = ['!', ')', ',', '.', ';', ':', '?', '»', '...', '..', '....', '%', 'º', '²', '°', '¿', '¡', '(', '«',
        '"', '\'', '-', '。', '·', '।', '۔']


def pre_token(line):
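    '''
    Roughly tokenize a space-delimited line: peel leading and trailing
    punctuation marks and digits off every space-separated unit into
    tokens of their own. Used when is_space == 'sea', where lines are
    handled as token lists rather than character strings.
    e.g. pre_token(u'(abc),') -> [u'(', u'abc', u')', u',']
    '''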
    out = []
    for seg in line.split(' '):
        f_out = []
        b_out = []
        while len(seg) > 0 and (seg[0] in punc or ('0' <= seg[0] <= '9')):
            f_out.append(seg[0])
            seg = seg[1:]
        while len(seg) > 0 and (seg[-1] in punc or ('0' <= seg[-1] <= '9')):
            b_out = [seg[-1]] + b_out
            seg = seg[:-1]
        if len(seg) > 0:
            out += f_out + [seg] + b_out
        else:
            out += f_out + b_out
    return out


def get_chars(path, filelist, sea=False):
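    '''
    Collect the character vocabulary of the given files into chars.txt.
    Frequencies are accumulated only for the first file in filelist
    (assumed to be the training set); characters first seen in later
    files keep a count of 1 and are therefore treated as unknown
    characters by get_dicts. Pass sea='sea' to pre-tokenize each line.
    '''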
    char_set = {}
    out_char = codecs.open(path + '/chars.txt', 'w', encoding='utf-8')
    for i, file_name in enumerate(filelist):
        for line in codecs.open(path + '/' + file_name, 'rb', encoding='utf-8'):
            line = line.strip()
            if sea == 'sea':
                line = pre_token(line)
            for ch in line:
                if ch in char_set:
                    if i == 0:
                        char_set[ch] += 1
                else:
                    char_set[ch] = 1
    for k, v in char_set.items():
        out_char.write(k + '\t' + str(v) + '\n')
    out_char.close()


def get_dicts(path, sent_seg, tag_scheme='BIES', crf=1):
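    '''
    Build the index dictionaries from chars.txt and tags.txt: characters
    with frequency 1 are recorded in unk_chars_idx, CRF models get the
    extra '<P>' padding tag, and sentence segmentation adds the 'T'/'U'
    sentence-boundary tags. If dict.txt exists, the top-ranked (most
    frequent) translation of each multi-word token is loaded into
    trans_dict.
    '''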
    char2idx = {'<P>': 0, '<UNK>': 1, '<#>': 2}
    unk_chars_idx = []
    idx = 3
    for line in codecs.open(path + '/chars.txt', 'r', encoding='utf-8'):
        segs = line.split('\t')
        if len(segs[0].strip()) == 0:
            if ' ' not in char2idx:
                char2idx[' '] = idx
                idx += 1
        else:
            char2idx[segs[0]] = idx
            if int(segs[1]) == 1:
                unk_chars_idx.append(idx)
            idx += 1
    idx2char = {k: v for v, k in char2idx.items()}
    if tag_scheme == 'BI':
        if crf > 0:
            tag2idx = {'<P>': 0, 'B': 1, 'I': 2}
            idx = 3
        else:
            tag2idx = {'B': 0, 'I': 1}
            idx = 2
    else:
        if crf > 0:
            tag2idx = {'<P>': 0, 'B': 1, 'I': 2, 'E': 3, 'S': 4}
            idx = 5
        else:
            tag2idx = {'B': 0, 'I': 1, 'E': 2, 'S': 3}
            idx = 4
    for line in codecs.open(path + '/tags.txt', 'r', encoding='utf-8'):
        line = line.strip()
        if line not in tag2idx:
            tag2idx[line] = idx
            idx += 1
    if sent_seg:
        tag2idx['T'] = idx
        tag2idx['U'] = idx + 1
    idx2tag = {k: v for v, k in tag2idx.items()}

    trans_dict = {}
    key = ''
    if os.path.isfile(path + '/dict.txt'):
        for line in codecs.open(path + '/dict.txt', 'r', encoding='utf-8'):
            line = line.strip()
            if len(line) > 0:
                segs = line.split('\t')
                if len(segs) == 1:
                    key = segs[0]
                    trans_dict[key] = None
                elif len(segs) == 2:
                    # keep only the first (most frequent) translation of each key
                    if key and trans_dict[key] is None:
                        trans_dict[key] = segs[0].replace(' ', '  ')

    return char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict


def ngrams(raw, gram, is_space):
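    '''
    Count character (or token, for 'sea') n-grams of order `gram`,
    padding with '<PAD>' at line boundaries. N-grams spanning two
    consecutive lines are counted as well, joined with a space when
    is_space is True.
    '''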
    gram_set = {}
    li = gram/2
    ri = gram - li - 1
    p = '<PAD>'
    last_line = ''
    is_first = True
    for line in raw:
        for i in range(len(line)):
            if i - li < 0:
                if is_space != 'sea':
                    lp = p * (li - i) + line[:i]
                else:
                    lp = [p] * (li - i) + line[:i]
            else:
                lp = line[i - li:i]
            if i + ri + 1 > len(line):
                if is_space != 'sea':
                    rp = line[i:] + p*(i + ri + 1 - len(line))
                else:
                    rp = line[i:] + [p] * (i + ri + 1 - len(line))
            else:
                rp = line[i:i+ri+1]
            ch = lp + rp
            if is_space == 'sea':
                ch = '_'.join(ch)
            if ch in gram_set:
                gram_set[ch] += 1
            else:
                gram_set[ch] = 1
        if is_first:
            is_first = False
        else:
            if is_space is True:
                last_line += ' '
            start_idx = len(last_line) - ri
            if start_idx < 0:
                start_idx = 0
            end_idx = li + len(last_line)
            j_line = last_line + line
            for i in range(start_idx, end_idx):
                if i - li < 0:
                    if is_space != 'sea':
                        j_lp = p * (-i) + j_line[start_idx:i]
                    else:
                        j_lp = [p] * (-i) + j_line[start_idx:i]
                else:
                    j_lp = j_line[i - li:i]
                if i + ri + 1 > len(j_line):
                    if is_space != 'sea':
                        j_rp = j_line[i:end_idx] + p * (ri + i + 1 - len(j_line))
                    else:
                        j_rp = j_line[i:end_idx] + [p] * (ri + i + 1 - len(j_line))
                else:
                    j_rp = j_line[i:ri + 1 + i]
                j_ch = j_lp + j_rp
                if is_space == 'sea':
                    j_ch = '_'.join(j_ch)
                if j_ch in gram_set:
                    gram_set[j_ch] += 1
                else:
                    gram_set[j_ch] = 1
        last_line = line
    return gram_set


def get_ngrams(path, ng, is_space):
    raw = []
    for line in codecs.open(path + '/raw_train.txt', 'r', encoding='utf-8'):
        if is_space == 'sea':
            segs = pre_token(line.strip())
        else:
            segs = line.strip()
        raw.append(segs)
    if ng > 1:
        for i in range(2, ng + 1):
            out_gram = codecs.open(path + '/' + str(i) + 'gram.txt', 'w', encoding='utf-8')
            grams = ngrams(raw, i, is_space)
            for k, v in grams.items():
                out_gram.write(k + '\t' + str(v) + '\n')
            out_gram.close()


def read_ngrams(path, ng):
    ngs = []
    for i in range(2, ng + 1):
        gram_dic = {}
        for line in codecs.open(path + '/' + str(i) + 'gram.txt', 'r', encoding='utf-8'):
            line = line.rstrip()
            segs = line.split('\t')
            # restore trailing spaces that rstrip removed from space-final n-grams
            while len(segs[0]) < i:
                segs[0] += ' '
            gram_dic[segs[0]] = int(segs[1])
        ngs.append(gram_dic)
    return ngs


def get_sample_embedding(path, emb, chars2idx):
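    '''
    Extract the subset of a pre-trained embedding file that covers the
    characters in chars2idx and write it to <emb-name>_sub.txt, one
    character per line followed by its vector values.
    '''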
    chars = chars2idx.keys()
    short_emb = emb[emb.index('/') + 1: emb.index('.')]
    emb_dic = {}
    valid_chars = []
    for line in codecs.open(emb, 'rb', encoding='utf-8'):
        line = line.strip()
        sets = line.split(' ')
        emb_dic[sets[0]] = np.asarray(sets[1:], dtype='float32')
    fout = codecs.open(path + '/' + short_emb + '_sub.txt', 'w', encoding='utf-8')
    for ch in chars:
        p_line = ch
        if ch in emb_dic:
            valid_chars.append(ch)
            for value in emb_dic[ch]:
                p_line += ' ' + unicode(value)
            fout.write(p_line + '\n')
    fout.close()


def read_sample_embedding(path, short_emb, char2idx):
    emb_values = []
    valid_chars = []
    emb_dic = {}
    for line in codecs.open(path + '/' + short_emb + '_sub.txt', 'rb', encoding='utf-8'):
        first_ch = line[0]
        line = line.rstrip()
        sets = line.split(' ')
        if first_ch == ' ':
            emb_dic[' '] = np.asarray(sets, dtype='float32')
        else:
            emb_dic[sets[0]] = np.asarray(sets[1:], dtype='float32')
    emb_dim = len(emb_dic.items()[0][1])
    for ch in char2idx.keys():
        if ch in emb_dic:
            emb_values.append(emb_dic[ch])
            valid_chars.append(ch)
        else:
            rand = np.random.uniform(-math.sqrt(float(3) / emb_dim), math.sqrt(float(3) / emb_dim), emb_dim)
            emb_values.append(np.asarray(rand, dtype='float32'))
    emb_dim = len(emb_values[0])
    return emb_dim, emb_values, valid_chars


def get_sent_raw(path, fname, is_space=True):
    long_line = ''
    for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'):
        line = line.strip()
        if is_space:
            long_line += ' ' + line
        else:
            long_line += line
    if is_space:
        long_line = long_line[1:]

    return long_line


def chop(line, ad_s, limit):
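    '''
    Chop an index sequence into overlapping segments of at most `limit`
    items. Continuation segments start with the marker index `ad_s`,
    consecutive segments overlap by about ten items so that they can be
    stitched back together in generate_output, and the final segment is
    zero-padded to length limit - 1.
    '''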
    out = []
    chopped = False
    while len(line) > 0:
        if chopped:
            s_line = line[:limit - 1]
            s_line = [ad_s] + s_line
        else:
            chopped = True
            s_line = line[:limit]
        out.append(s_line)
        line = line[limit - 10:]
        if len(line) < 10:
            line = ''
    while len(out) > 0 and len(out[-1]) < limit-1:
        out[-1].append(0)
    return out


def get_input_vec(path, fname, char2idx, tag2idx, limit=500, sent_seg=False, is_space=True, train_size=-1, ignore_space=False):
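    '''
    Read a training file of alternating raw and tag lines (separated by
    blank lines) and convert both into index sequences. With sent_seg
    the whole corpus is concatenated and re-chopped into fixed-length
    segments, and the final tag of every sentence is turned into the
    sentence-boundary tag 'T' (after 'S') or 'U' (otherwise); without
    it, over-long sentences are chopped individually.
    '''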
    ct = 0
    max_len = 0
    space_idx = None
    if is_space is True:
        assert ' ' in char2idx
        space_idx = char2idx[' ']
    x_indices = []
    y_indices = []
    s_count = 0
    l_count = 0
    x = []
    y = []

    n_sent = 0

    if sent_seg:
        for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'):
            line = line.strip()
            if len(line) == 0:
                ct = 0
            elif ct == 0:
                if is_space == 'sea':
                    line = pre_token(line)
                for ch in line:
                    if len(ch.strip()) == 0:
                        x.append(char2idx[' '])
                    elif ch in char2idx:
                        x.append(char2idx[ch])
                    else:
                        x.append(char2idx['<UNK>'])
                if is_space is True and not ignore_space:
                    x = [space_idx] + x
                x_indices += x
                x = []
                ct = 1
            elif ct == 1:
                for ch in line:
                    y.append(tag2idx[ch])
                if y[-1] == tag2idx['S']:
                    y[-1] = tag2idx['T']
                else:
                    y[-1] = tag2idx['U']
                if is_space is True and not ignore_space:
                    y = [tag2idx['X']] + y
                y_indices += y
                y = []
                n_sent += 1
            if 0 < train_size <= n_sent:
                break
        x_indices = chop(x_indices, char2idx['<#>'], limit)
        y_indices = chop(y_indices, tag2idx['I'], limit)
        max_len = limit
    else:
        for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'):
            line = line.strip()
            if len(line) == 0:
                ct = 0
            elif ct == 0:
                if is_space == 'sea':
                    line = pre_token(line)
                max_len = max(max_len, len(line))
                s_count += 1
                if len(line) > limit:
                    l_count += 1
                chopped = False
                while len(line) > 0:
                    s_line = line[:limit - 1]
                    line = line[limit - 10:]
                    if len(line) < 10:
                        line = ''
                    if not chopped:
                        chopped = True
                    else:
                        x.append(char2idx['<#>'])
                    for ch in s_line:
                        if len(ch.strip()) == 0:
                            x.append(char2idx[' '])
                        elif ch in char2idx:
                            x.append(char2idx[ch])
                        else:
                            x.append(char2idx['<UNK>'])
                    x_indices.append(x)
                    x = []
                ct = 1
            elif ct == 1:
                chopped = False
                while len(line) > 0:
                    s_line = line[:limit - 1]
                    line = line[limit - 10:]
                    if len(line) < 10:
                        line = ''
                    if not chopped:
                        chopped = True
                    else:
                        y.append(tag2idx['I'])
                    for ch in s_line:
                        y.append(tag2idx[ch])
                    y_indices.append(y)
                    y = []
                n_sent += 1
            if 0 < train_size <= n_sent:
                break
        max_len = min(max_len, limit)
        if l_count > 0:
            print '%d (out of %d) sentences are chopped.' % (l_count, s_count)
    return [x_indices], [y_indices], max_len


def get_input_vec_sent(path, fname, char2idx, win_size, is_space=True):
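    '''
    Build sliding-window character features for sentence segmentation:
    for every character of the current line the window covers win_size
    characters of context on either side, joined across the previous
    and following lines. The label is 1 for a line-final character and
    0 otherwise.
    '''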
    pre_line = ''
    c_line = ''
    x = []
    y = []
    is_first = True
    for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'):
        line = line.strip()
        if is_space == 'sea':
            line = pre_token(line)
        start_idx = len(pre_line)
        if is_space is True:
            j_line = pre_line + ' ' + c_line + ' ' + line
            end_idx = start_idx + len(c_line) + 1
            if is_first:
                is_first = False
                j_line = j_line[1:]
                end_idx -= 1
        else:
            j_line = pre_line + c_line + line
            end_idx = start_idx + len(c_line)
        for i in range(start_idx, end_idx):
            indices = []
            for j in range(i - win_size, i + win_size + 1):
                if j < 0 or j >= len(j_line):
                    indices.append(char2idx['<P>'])
                else:
                    if j_line[j] in char2idx:
                        indices.append(char2idx[j_line[j]])
                    else:
                        indices.append(char2idx['<UNK>'])
            x.append(indices)
            if i == end_idx - 1:
                y.append(1)
            else:
                y.append(0)
        pre_line = c_line
        c_line = line
    if is_space is True:
        j_line = pre_line + ' ' + c_line
    else:
        j_line = pre_line + c_line
    start_idx = len(pre_line)
    end_idx = start_idx + len(c_line)
    for i in range(start_idx, end_idx):
        indices = []
        for j in range(i - win_size, i + win_size + 1):
            if j < 0 or j >= len(j_line):
                indices.append(char2idx['<P>'])
            else:
                if j_line[j] in char2idx:
                    indices.append(char2idx[j_line[j]])
                else:
                    indices.append(char2idx['<UNK>'])
        x.append(indices)
        if i == end_idx - 1:
            y.append(1)
        else:
            y.append(0)

    assert len(x) == len(y)
    return x, y


def get_input_vec_sent_raw(raws, char2idx, win_size):
    x = []
    for i in range(len(raws)):
        indices = []
        for j in range(i - win_size, i + win_size + 1):
            if j < 0 or j >= len(raws):
                indices.append(char2idx['<P>'])
            else:
                if raws[j] in char2idx:
                    indices.append(char2idx[raws[j]])
                else:
                    indices.append(char2idx['<UNK>'])
        x.append(indices)
    return x


def get_input_vec_raw(path, fname, char2idx, lines=None, limit=500, sent_seg=False, is_space=True, ignore_space=False):
    max_len = 0
    space_idx = None
    is_first = True
    if is_space is True:
        assert ' ' in char2idx
        space_idx = char2idx[' ']
    x_indices = []
    s_count = 0
    l_count = 0
    x = []
    if lines is None:
        assert fname is not None
        if path is None:
            real_path = fname
        else:
            real_path = path + '/' + fname
        lines = codecs.open(real_path, 'r', encoding='utf-8')
    if sent_seg:
        for line in lines:
            line = line.strip()
            if is_space == 'sea':
                line = pre_token(line)
            elif ignore_space:
                line = ''.join(line.split())
            for ch in line:
                if len(ch.strip()) == 0:
                    x.append(char2idx[' '])
                elif ch in char2idx:
                    x.append(char2idx[ch])
                else:
                    x.append(char2idx['<UNK>'])
            if is_space is True and not ignore_space:
                if is_first:
                    is_first = False
                else:
                    x = [space_idx] + x
            x_indices += x
            x = []
        x_indices = chop(x_indices, char2idx['<#>'], limit)
        max_len = limit
    else:
        for line in lines:
            line = line.strip()
            if len(line) > 0:
                if is_space == 'sea':
                    line = pre_token(line)
                elif ignore_space:
                    line = ''.join(line.split())
                max_len = max(max_len, len(line))
                s_count += 1

                for ch in line:
                    if len(ch.strip()) == 0:
                        x.append(char2idx[' '])
                    elif ch in char2idx:
                        x.append(char2idx[ch])
                    else:
                        x.append(char2idx['<UNK>'])

                if len(line) > limit:
                    l_count += 1
                    chop_x = chop(x, char2idx['<#>'], limit)
                    x_indices += chop_x
                else:
                    x_indices.append(x)
                x = []
        max_len = min(max_len, limit)
        if l_count > 0:
            print '%d (out of %d) sentences are chopped.' % (l_count, s_count)
    return [x_indices], max_len


def get_input_vec_tag(path, fname, char2idx, lines=None, limit=500, is_space=True):
    space_idx = None
    if is_space is True:
        assert ' ' in char2idx
        space_idx = char2idx[' ']
    x_indices = []
    out = []
    x = []
    is_first = True
    if lines is None:
        assert fname is not None
        if path is None:
            real_path = fname
        else:
            real_path = path + '/' + fname
        lines = codecs.open(real_path, 'r', encoding='utf-8')
    for line in lines:
        line = line.strip()
        if len(line) > 0:
            if is_space == 'sea':
                line = pre_token(line)
            if len(line) > 0:
                for ch in line:
                    if len(ch.strip()) == 0:
                        x.append(char2idx[' '])
                    elif ch in char2idx:
                        x.append(char2idx[ch])
                    else:
                        x.append(char2idx['<UNK>'])
                if is_space is True:
                    if is_first:
                        is_first = False
                    else:
                        x = [space_idx] + x
                x_indices += x
                x = []
            elif len(x_indices) > 0:
                x_indices = chop(x_indices, char2idx['<#>'], limit)
                out += x_indices
                x_indices = []
                is_first = True

    if len(x_indices) > 0:
        x_indices = chop(x_indices, char2idx['<#>'], limit)
        out += x_indices

    return [out], limit


def get_vecs(seq, char2idx):
    out = []
    for ch in seq:
        if ch in char2idx:
            out.append(char2idx[ch])
    return out


def get_dict_vec(trans_dict, char2idx):
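    '''
    Turn the transduction dictionary into padded index sequences for
    training the transducer, appending index 2 ('<#>' in char2idx) as
    the end-of-sequence symbol, and split the shuffled pairs 95/5 into
    training and validation sets.
    '''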
    max_x, max_y = 0, 0
    x = []
    y = []
    for k, v in trans_dict.items():
        x.append(get_vecs(k, char2idx))
        y.append(get_vecs(v.replace('  ', ' '), char2idx) + [2])
        if len(k) > max_x:
            max_x = len(k)
        if len(v) > max_y:
            max_y = len(v)
    max_x += 5
    max_y += 5
    x = pad_zeros(x, max_x)
    y = pad_zeros(y, max_y)
    assert len(x) == len(y)
    num = len(x)
    xy = zip(x, y)
    random.shuffle(xy)
    xy = zip(*xy)
    t_x = xy[0][:int(num * 0.95)]
    t_y = xy[1][:int(num * 0.95)]
    v_x = xy[0][int(num * 0.95):]
    v_y = xy[1][int(num * 0.95):]
    return t_x, t_y, v_x, v_y


def get_ngram_dic(ng):
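    '''
    Index the n-gram vocabularies: grams seen more than once get an
    index of their own, singletons are mapped to '<UNK>' (index 1).
    Note that idx is advanced for singletons too, so the index space
    reserves a slot per gram.
    '''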
    gram_dics = []
    for i, gram in enumerate(ng):
        g_dic = {'<P>': 0, '<UNK>': 1, '<#>': 2}
        idx = 3
        for g in gram.keys():
            if gram[g] > 1:
                g_dic[g] = idx
            else:
                g_dic[g] = 1
            idx += 1
        gram_dics.append(g_dic)
    return gram_dics


def gram_vec(raw, dic, limit=500, sent_seg=False, is_space=True):
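    '''
    Map a list of raw lines to n-gram index sequences with dictionary
    `dic`, mirroring the padding and line-joining logic of ngrams().
    With sent_seg the output is one long sequence chopped into
    fixed-length segments; otherwise each line becomes its own
    (possibly chopped) sequence.
    '''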
    out = []
    # infer the n-gram order from a real key: dict order is arbitrary in
    # Python 2, so skip the special marker entries
    ngram = 0
    for k in dic.keys():
        if k in ('<P>', '<UNK>', '<#>') or '<PAD>' in k:
            continue
        if is_space == 'sea':
            ngram = len(k.split('_'))
        else:
            ngram = len(k)
        break
    li = ngram/2
    ri = ngram - li - 1
    p = '<PAD>'
    indices = []
    is_first = True
    if sent_seg:
        last_line = ''
        for line in raw:
            for i in range(len(line)):
                if i - li < 0:
                    if is_space != 'sea':
                        lp = p * (li - i) + line[:i]
                    else:
                        lp = [p] * (li - i) + line[:i]
                else:
                    lp = line[i - li:i]
                if i + ri + 1 > len(line):
                    if is_space != 'sea':
                        rp = line[i:] + p * (i + ri + 1 - len(line))
                    else:
                        rp = line[i:] + [p] * (i + ri + 1 - len(line))
                else:
                    rp = line[i:i + ri + 1]
                ch = lp + rp
                if is_space == 'sea':
                    ch = '_'.join(ch)
                if ch in dic:
                    indices.append(dic[ch])
                else:
                    indices.append(dic['<UNK>'])
            if is_first:
                is_first = False
            else:
                start_idx = len(last_line) - ri
                if start_idx < 0:
                    start_idx = 0
                if is_space:
                    last_line += ' '
                j_line = last_line + line
                end_idx = len(last_line) + li
                j_indices = []
                for i in range(start_idx, end_idx):
                    if i - li < 0:
                        if is_space != 'sea':
                            j_lp = p * (-i) + j_line[start_idx:i]
                        else:
                            j_lp = [p] * (-i) + j_line[start_idx:i]
                    else:
                        j_lp = j_line[i - li:i]
                    if i + ri + 1 > len(j_line):
                        if is_space != 'sea':
                            j_rp = j_line[i:end_idx] + p * (ri + i + 1 - len(j_line))
                        else:
                            j_rp = j_line[i:end_idx] + [p] * (ri + i + 1 - len(j_line))
                    else:
                        j_rp = j_line[i:ri + 1 + i]
                    j_ch = j_lp + j_rp
                    if is_space == 'sea':
                        j_ch = '_'.join(j_ch)
                    if j_ch in dic:
                        j_indices.append(dic[j_ch])
                    else:
                        j_indices.append(dic['<UNK>'])
                if ri > 0:
                    out = out[:-ri] + j_indices[:ri]
                if is_space:
                    indices = j_indices[-(li + 1):] + indices[li:]
                else:
                    indices = j_indices[-li:] + indices[li:]
            out += indices
            indices = []
            last_line = line
        out = chop(out, dic['<#>'], limit)

    else:
        for line in raw:
            chopped = False
            while len(line) > 0:
                s_line = line[:limit - 1]
                line = line[limit - 10:]
                if len(line) < 10:
                    line = ''
                if not chopped:
                    chopped = True
                else:
                    indices.append(dic['<#>'])
                for i in range(len(s_line)):
                    if i - li < 0:
                        if is_space != 'sea':
                            lp = p * (li - i) + s_line[:i]
                        else:
                            lp = [p] * (li - i) + s_line[:i]
                    else:
                        lp = s_line[i - li:i]
                    if i + ri + 1 > len(s_line):
                        if is_space != 'sea':
                            rp = s_line[i:] + p * (i + ri + 1 - len(s_line))
                        else:
                            rp = s_line[i:] + [p] * (i + ri + 1 - len(s_line))
                    else:
                        rp = s_line[i:i + ri + 1]
                    ch = lp + rp
                    if is_space == 'sea':
                        ch = '_'.join(ch)
                    if ch in dic:
                        indices.append(dic[ch])
                    else:
                        indices.append(dic['<UNK>'])
                out.append(indices)
                indices = []
    return out


def get_gram_vec(path, fname, gram2index, lines=None, is_raw=False, limit=500, sent_seg=False, is_space=True, ignore_space=False):
    raw = []
    i = 0
    if lines is None:
        assert fname is not None
        if path is None:
            real_path = fname
        else:
            real_path = path + '/' + fname
        lines = codecs.open(real_path, 'r', encoding='utf-8')
    for line in lines:
        line = line.strip()
        if is_space == 'sea':
            line = pre_token(line)
        elif ignore_space:
            line = ''.join(line.split())
        if i == 0 or is_raw:
            raw.append(line)
            i += 1
        if len(line) > 0:
            i += 1
        else:
            i = 0
    out = []
    for g_dic in gram2index:
        out.append(gram_vec(raw, g_dic, limit, sent_seg, is_space))
    return out


def get_gram_vec_tag(path, fname, gram2index, lines=None, limit=500, is_space=True, ignore_space=False):
    raw = []
    out = [[] for _ in range(len(gram2index))]
    if lines is None:
        assert fname is not None
        if path is None:
            real_path = fname
        else:
            real_path = path + '/' + fname
        lines = codecs.open(real_path, 'r', encoding='utf-8')
    for line in lines:
        line = line.strip()
        if is_space == 'sea':
            line = pre_token(line)
        elif ignore_space:
            line = ''.join(line.split())
        if len(line) > 0:
            raw.append(line)
        else:
            for i, g_dic in enumerate(gram2index):
                out[i] += gram_vec(raw, g_dic, limit, True, is_space)
            raw = []
    if len(raw) > 0:
        for i, g_dic in enumerate(gram2index):
            out[i] += gram_vec(raw, g_dic, limit, True, is_space)
    return out


def read_vocab_tag(path):
    '''
    Read tags from an index file (one tag per line) and create dictionaries.
    :param path: path to the tag index file
    :return: tag2idx, idx2tag
    '''
    tag2idx = {}
    for i, line in enumerate(codecs.open(path, 'rb', encoding='utf-8')):
        line = line.strip()
        tag2idx[line] = i
    idx2tag = {k: v for v, k in tag2idx.items()}
    return tag2idx, idx2tag


def get_tags(can, action='sep', tag_scheme='BIES', ignore_mwt=False):
    tags = []
    if tag_scheme == 'BI':
        for i in range(len(can)):
            if i == 0:
                if action == 'sep' or ignore_mwt:
                    tags.append('B')
                else:
                    tags.append('K')
            else:
                if action == 'sep' or ignore_mwt:
                    tags.append('I')
                else:
                    tags.append('Z')
    else:
        for i in range(len(can)):
            if len(can) == 1:
                if action == 'sep' or ignore_mwt:
                    tags.append('S')
                else:
                    tags.append('D')
            elif i == 0:
                if action == 'sep' or ignore_mwt:
                    tags.append('B')
                else:
                    tags.append('K')
            elif i == len(can) - 1:
                if action == 'sep' or ignore_mwt:
                    tags.append('E')
                else:
                    tags.append('J')
            else:
                if action == 'sep' or ignore_mwt:
                    tags.append('I')
                else:
                    tags.append('Z')
    return tags


def update_dict(trans_dic, can, trans):
    can = can.lower()
    if can not in trans_dic:
        trans_dic[can] = {}
    if trans not in trans_dic[can]:
        trans_dic[can][trans] = 1
    else:
        trans_dic[can][trans] += 1
    return trans_dic


def raw2tags(raw, sents, path, train_file, creat_dict=True, gold_path=None, ignore_space=False, reset=False,
             tag_scheme='BIES', ignore_mwt=False):
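    '''
    Align raw text with the gold-segmented sentences and write the
    training file: each raw line followed by its character-level tag
    line. CoNLL-U-style id ranges ('n-m') mark multi-word tokens; those
    whose surface form does not match the concatenation of their word
    forms are tagged for transduction and their expansions collected
    into dict.txt. Sentences whose tags cannot be aligned to the raw
    characters are counted as invalid and skipped.
    '''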
    wt = codecs.open(path + '/' + train_file, 'w', encoding='utf-8')
    if creat_dict and not ignore_mwt:
        wd = codecs.open(path + '/dict.txt', 'w', encoding='utf-8')
    wg = None
    if gold_path is not None:
        wg = codecs.open(path + '/' + gold_path, 'w', encoding='utf-8')
    wtg = None
    if reset or not os.path.isfile(path + '/tags.txt'):
        wtg = codecs.open(path + '/tags.txt', 'w', encoding='utf-8')
    final_dic = {}
    assert len(raw) == len(sents)
    invalid = 0
    s_tags = set()

    def matched(can, sent_l, tags, trans_dic):
        if '-' in sent_l[0][0]:
            nums = sent_l[0][0].split('-')
            count = int(nums[1]) - int(nums[0])
            sent_l.pop(0)
            segs = []
            while count >= 0:
                segs.append(sent_l[0][1])
                sent_l.pop(0)
                count -= 1
            j_seg = ''.join(segs)
            if j_seg == can:
                for seg in segs:
                    tags += get_tags(seg, tag_scheme=tag_scheme)
            elif can.replace('-', '') == j_seg:
                for c_split in can.split('-'):
                    tags += get_tags(c_split, tag_scheme=tag_scheme)
                    if tag_scheme == 'BI':
                        tags.append('I')
                    else:
                        tags.append('X')
                tags.pop()
            else:
                tags += get_tags(can, action='trans', tag_scheme=tag_scheme, ignore_mwt=ignore_mwt)
                if not ignore_mwt:
                    trans = ' '.join(segs)
                    trans_dic = update_dict(trans_dic, can, trans)
        else:
            tags += get_tags(can, tag_scheme=tag_scheme)
            sent_l.pop(0)

        return tags, trans_dic

    for raw_l, sent_l in zip(raw, sents):
        if ignore_space:
            raw_l = ''.join(raw_l.split())
        tags = []
        cans = raw_l.split(' ')
        trans_dic = {}
        gold = get_gold(sent_l, ignore_mwt=ignore_mwt)
        pre = ''
        for can in cans:
            t_can = can.strip()
            purged = len(can) - len(t_can)
            if purged > 0:
                can = t_can
            while purged > 0:
                if tag_scheme == 'BI':
                    tags.append('I')
                else:
                    tags.append('X')
                purged -= 1
            done = False
            if len(pre) > 0:
                can = pre + ' ' + can
            while not done:
                if can == sent_l[0][1]:
                    tags, trans_dic = matched(can, sent_l, tags, trans_dic)
                    done = True
                    pre = ''
                else:
                    if len(can) >= len(sent_l[0][1]):
                        s_l = len(sent_l[0][1])
                        s_can = can[:s_l]
                        if s_can != sent_l[0][1]:
                            done = True
                        tags, trans_dic = matched(s_can, sent_l, tags, trans_dic)
                        can = can[s_l:]
                        if len(can) == 0:
                            done = True
                            pre = ''
                    else:
                        pre = can
                        done = True
            if len(pre) == 0:
                if tag_scheme == 'BI':
                    tags.append('I')
                else:
                    tags.append('X')
        if len(tags) > 0:
            tags.pop()
        if len(tags) == len(raw_l):
            for tg in tags:
                s_tags.add(tg)
            wt.write(raw_l + '\n')
            wt.write(''.join(tags) + '\n')
            wt.write('\n')
            for key in trans_dic:
                if key not in final_dic:
                    final_dic[key] = trans_dic[key]
                else:
                    for tr in trans_dic[key]:
                        if tr in final_dic[key]:
                            final_dic[key][tr] += trans_dic[key][tr]
                        else:
                            final_dic[key][tr] = trans_dic[key][tr]
        else:
            invalid += 1
        if wg is not None:
            wg.write(gold + '\n')
    if wg is not None:
        wg.close()
    if wtg is not None:
        for stg in s_tags:
            wtg.write(stg + '\n')
        wtg.close()
    if creat_dict and not ignore_mwt:
        for key in final_dic:
            wd.write(key + '\n')
            s_dic = sorted(final_dic[key].items(), key=lambda x: x[1], reverse=True)
            for i in s_dic:
                wd.write(i[0] + '\t' + str(i[1]) + '\n')
            wd.write('\n')
        wd.close()
    wt.close()
    print 'invalid sentences: %d out of %d' % (invalid, len(raw))


def raw2tags_sea(raw, sents, path, train_file, gold_path=None, reset=False, tag_scheme='BIES'):
    wt = codecs.open(path + '/' + train_file, 'w', encoding='utf-8')
    wg = None
    if gold_path is not None:
        wg = codecs.open(path + '/' + gold_path, 'w', encoding='utf-8')
    assert len(raw) == len(sents)
    invalid = 0
    wtg = None
    if reset or not os.path.isfile(path + '/tags.txt'):
        wtg = codecs.open(path + '/tags.txt', 'w', encoding='utf-8')

    s_tags = set()

    def matched(can, sent_l, tags):
        segs = can.split(' ')
        sent_l.pop(0)
        if len(segs) == 1:
            tags.append('S')
        elif len(segs) > 1:
            if tag_scheme == 'BI':
                tags += ['B'] + ['I'] * (len(segs) - 1)
            else:
                mid_t = ['I'] * (len(segs) - 2)
                tags += ['B'] + mid_t + ['E']
        return tags

    for raw_l, sent_l in zip(raw, sents):
        tags = []
        cans = pre_token(raw_l)
        gold = get_gold(sent_l)
        pre = ''
        for can in cans:
            t_can = can.strip()
            purged = len(can) - len(t_can)
            if purged > 0:
                can = t_can
            while purged > 0:
                if tag_scheme == 'BI':
                    tags.append('I')
                else:
                    tags.append('X')
                purged -= 1
            if len(pre) > 0:
                can = pre + ' ' + can
            j_can = ''.join(can.split())
            j_sent = ''
            if sent_l:
                j_sent = ''.join(sent_l[0][1].split())
            if j_can == j_sent:
                tags = matched(can, sent_l, tags)
                pre = ''
            else:
                assert len(j_can) < len(j_sent)
                pre = can
        if len(tags) == len(cans):
            for tg in tags:
                s_tags.add(tg)
            wt.write(raw_l + '\n')
            wt.write(''.join(tags) + '\n')
            wt.write('\n')
        else:
            invalid += 1
        if wg is not None:
            wg.write(gold + '\n')
    if wg is not None:
        wg.close()
    if wtg is not None:
        for stg in s_tags:
            wtg.write(stg + '\n')
        wtg.close()
    wt.close()

    print 'invalid sentences: %d out of %d' % (invalid, len(raw))


def pad_zeros(l, max_len):
    padded = None
    if type(l) is list:
        padded = []
        for item in l:
            if len(item) <= max_len:
                padded.append(np.pad(item, (0, max_len - len(item)), 'constant', constant_values=0))
            else:
                padded.append(np.asarray(item[:max_len]))
        padded = np.asarray(padded)
    elif type(l) is dict:
        padded = {}
        for k, v in l.iteritems():
            padded[k] = [np.pad(item, (0, max_len - len(item)), 'constant', constant_values=0) for item in v]
    return padded

def unpad_zeros(l):
    out = []
    for tags in l:
        out.append([np.trim_zeros(line) for line in tags])
    return out


def buckets(x, y, size=50):
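    '''
    Sort the parallel input/output sequences by length and group them
    into buckets whose maximum length grows in steps of `size`, keeping
    the padding within each bucket small.
    '''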
    assert len(x[0]) == len(y[0])
    num_inputs = len(x)
    samples = x + y
    num_items = len(samples)
    xy = zip(*samples)
    xy.sort(key=lambda i: len(i[0]))
    t_len = size
    idx = 0
    bucks = [[[]] for _ in range(num_items)]
    for item in xy:
        if len(item[0]) > t_len:
            if len(bucks[0][idx]) > 0:
                for buck in bucks:
                    buck.append([])
                idx += 1
            while len(item[0]) > t_len:
                t_len += size
        for i in range(num_items):
            bucks[i][idx].append(item[i])

    return bucks[:num_inputs], bucks[num_inputs:]


def pad_bucket(x, y, limit, bucket_len_c=None):
    assert len(x[0]) == len(y[0])
    num_inputs = len(x)
    num_tags = len(y)
    padded = [[] for _ in range(num_tags + num_inputs)]
    bucket_counts = []
    samples = x + y
    xy = zip(*samples)
    if bucket_len_c is None:
        bucket_len_c = []
        for i, item in enumerate(xy):
            max_len = len(item[0][-1])
            if i == len(xy) - 1:
                max_len = limit
            bucket_len_c.append(max_len)
            bucket_counts.append(len(item[0]))
            for idx in range(num_tags + num_inputs):
                padded[idx].append(pad_zeros(item[idx], max_len))
        print 'Number of buckets: ', len(bucket_len_c)
    else:
        idy = 0
        for item in xy:
            max_len = len(item[0][-1])
            while idy < len(bucket_len_c) and max_len > bucket_len_c[idy]:
                idy += 1
            bucket_counts.append(len(item[0]))
            if idy >= len(bucket_len_c):
                for idx in range(num_tags + num_inputs):
                    padded[idx].append(pad_zeros(item[idx], max_len))
                bucket_len_c.append(max_len)
            else:
                for idx in range(num_tags + num_inputs):
                    padded[idx].append(pad_zeros(item[idx], bucket_len_c[idy]))
    return padded[:num_inputs], padded[num_inputs:], bucket_len_c, bucket_counts


def get_real_batch(counts, b_size):
    real_batch_sizes = []
    for c in counts:
        if c < b_size:
            real_batch_sizes.append(c)
        else:
            real_batch_sizes.append(b_size)
    return real_batch_sizes


def merge_bucket(x):
    out = []
    for item in x:
        m = []
        for i in item:
            m += i
        out.append(m)
    return out


def decode_tags(idx, index2tags):
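    '''
    Map tag indices back to tag strings, collapsing the BIES scheme for
    output generation: 'E' -> 'I', 'S' -> 'B' and, for multi-word
    tokens, 'J' -> 'Z', 'D' -> 'K'.
    '''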
    out = []
    for id in idx:
        sents = []
        for line in id:
            sent = []
            for item in line:
                tag = index2tags[item]
                tag = tag.replace('E', 'I')
                tag = tag.replace('S', 'B')
                tag = tag.replace('J', 'Z')
                tag = tag.replace('D', 'K')
                sent.append(tag)
            sents.append(sent)
        out.append(sents)
    return out


def decode_chars(idx, idx2chars):
    out = []
    for line in idx:
        line = np.trim_zeros(line)
        out.append([idx2chars[item] for item in line])
    return out


def generate_output(chars, tags, trans_dict, transducer_dict=None, multi_tok=False, trans_type='mix'):
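    '''
    Convert decoded characters and tags back into tokenized text. After
    decode_tags the relevant tags are 'B'/'I' for word boundaries,
    'K'/'Z' for multi-word tokens to be transduced, 'T'/'U' for
    sentence boundaries and 'X' for whitespace. Overlapping chopped
    segments (marked by a leading '<#>') are stitched back together
    first. Returns the tokenized lines, the reconstructed raw lines
    and, with multi_tok, a variant that keeps each multi-word token
    annotated with the length of its expansion.
    '''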
    out = []
    mult_out = []
    raw_out = []
    sent_seg = False

    def map_trans(c_trans, t_type=trans_type):
        if c_trans in trans_dict and (t_type == 'mix' or t_type == 'dict'):
            c_trans = trans_dict[c_trans]
        elif transducer_dict is not None and (t_type == 'mix' or t_type == 'trans'):
            c_trans = transducer_dict(c_trans)
        sp = c_trans.split()
        c_trans = '  '.join(sp)

        return c_trans

    def add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=False):
        c_trans = c_trans.strip()
        if len(c_trans) > 0:
            if trans:
                o_trans = c_trans
                c_trans = map_trans(c_trans)
                if multi_tok:
                    num_tr = len(c_trans.split('  '))
                    mt_p_line += '  ' + o_trans + '!#!' + str(num_tr) + '  ' + c_trans
            else:
                if multi_tok:
                    mt_p_line += '  ' + c_trans
            p_line += '  ' + c_trans
        return p_line, mt_p_line

    def split_sent(lines, s_str):
        for i in range(len(lines)):
            s_line = lines[i].strip()
            # drop trailing sentence separators before splitting
            while s_line.endswith(s_str):
                s_line = s_line[:-len(s_str)].strip()
            sents = s_line.split(s_str)
            lines[i] = [sent.strip() for sent in sents]
        return lines

    for i, tag in enumerate(tags):
        assert len(chars) == len(tag)
        sub_out = []
        sub_raw_out = []
        multi_sub_out = []
        j_chars = []
        j_tags = []
        is_first = True
        for chs, tgs in zip(chars, tag):
            if chs[0] == '<#>':
                assert len(j_chars) > 0
                if is_first:
                    is_first = False
                    j_chars[-1] = j_chars[-1][:-5] + chs[6:]
                    j_tags[-1] = j_tags[-1][:-5] + tgs[6:]
                else:
                    j_chars[-1] = j_chars[-1][:-5] + chs[5:]
                    j_tags[-1] = j_tags[-1][:-5] + tgs[5:]
            else:
                j_chars.append(chs)
                j_tags.append(tgs)
                is_first = True
        chars = j_chars
        tag = j_tags
        for chs, tgs in zip(chars, tag):
            assert len(chs) == len(tgs)
            c_word = ''
            c_trans = ''
            p_line = ''
            r_line = ''
            mt_p_line = ''
            for ch, tg in zip(chs, tgs):
                r_line += ch
                if tg == 'I':
                    if len(c_trans) > 0:
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
                        c_trans = ''
                        c_word = ch
                    else:
                        c_word += ch
                elif tg == 'Z':
                    if len(c_word) > 0:
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
                        c_word = ''
                        c_trans = ch
                    else:
                        c_trans += ch
                elif tg == 'B':
                    if len(c_word) > 0:
                        c_word = c_word.strip()
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
                    elif len(c_trans) > 0:
                        c_trans = c_trans.strip()
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
                        c_trans = ''
                    c_word = ch
                elif tg == 'K':
                    if len(c_word) > 0:
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
                        c_word = ''
                    elif len(c_trans) > 0:
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
                    c_trans = ch
                elif tg == 'T':
                    sent_seg = True
                    if len(c_word) > 0:
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
                        c_word = ''
                    elif len(c_trans) > 0:
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
                        c_trans = ''
                    p_line += '  ' + ch + '<SENT>'
                    if multi_tok:
                        mt_p_line += '  ' + ch + '<SENT>'
                    r_line += '<SENT>'
                elif tg == 'U':
                    sent_seg = True
                    if len(c_word) > 0:
                        c_word += ch
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
                        c_word = ''
                    elif len(c_trans) > 0:
                        c_trans += ch
                        p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
                        c_trans = ''
                    elif len(ch.strip()) > 0:
                        p_line += ch
                        if multi_tok:
                            mt_p_line += ch
                    p_line += '<SENT>'
                    if multi_tok:
                        mt_p_line += '<SENT>'
                    r_line += '<SENT>'
                elif tg == 'X' and len(ch.strip()) > 0:
                    if len(c_word) > 0:
                        c_word += ch
                    elif len(c_trans) > 0:
                        c_trans += ch
                    else:
                        c_word = ch
                elif len(ch.strip()) > 0:
                    if len(c_word) > 0:
                        c_word += '  ' + ch
                    elif len(c_trans) > 0:
                        c_trans += '  ' + ch
                    else:
                        c_word = ch
            if len(c_word) > 0:
                c_word = c_word.strip()
                p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
            elif len(c_trans) > 0:
                c_trans = c_trans.strip()
                p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
            sub_out.append(p_line.strip())
            sub_raw_out.append(r_line.strip())
            if multi_tok:
                multi_sub_out.append(mt_p_line.strip())
        out.append(sub_out)
        raw_out.append(sub_raw_out)
        if multi_tok:
            mult_out.append(multi_sub_out)
    if out[0][-1].endswith('<SENT>'):
        out[0][-1] = out[0][-1][:-len('<SENT>')]
    if raw_out[0][-1].endswith('<SENT>'):
        raw_out[0][-1] = raw_out[0][-1][:-len('<SENT>')]
    if sent_seg:
        out = split_sent(out[0], '<SENT>')
        raw_out = split_sent(raw_out[0], '<SENT>')
    if multi_tok:
        if mult_out[0][-1].endswith('<SENT>'):
            mult_out[0][-1] = mult_out[0][-1][:-len('<SENT>')]
        if sent_seg:
            mult_out = split_sent(mult_out[0], '<SENT>')
        return out, raw_out, mult_out
    else:
        return out, raw_out


def generate_output_sea(chars, tags):
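    '''
    The 'sea' counterpart of generate_output: single spaces join the
    characters inside a word, double spaces separate words, and no
    multi-word-token transduction is performed.
    '''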
    out = []
    raw_out = []
    sent_seg = False

    def split_sent(lines, s_str):
        for i in range(len(lines)):
            s_line = lines[i].strip()
            # drop trailing sentence separators before splitting
            while s_line.endswith(s_str):
                s_line = s_line[:-len(s_str)].strip()
            sents = s_line.split(s_str)
            lines[i] = [sent.strip() for sent in sents]
        return lines

    for i, tag in enumerate(tags):
        assert len(chars) == len(tag)
        sub_out = []
        sub_raw_out = []
        j_chars = []
        j_tags = []
        is_first = True
        for chs, tgs in zip(chars, tag):
            if chs[0] == '<#>':
                assert len(j_chars) > 0
                if is_first:
                    is_first = False
                    j_chars[-1] = j_chars[-1][:-5] + chs[6:]
                    j_tags[-1] = j_tags[-1][:-5] + tgs[6:]
                else:
                    j_chars[-1] = j_chars[-1][:-5] + chs[5:]
                    j_tags[-1] = j_tags[-1][:-5] + tgs[5:]
            else:
                j_chars.append(chs)
                j_tags.append(tgs)
                is_first = True
        chars = j_chars
        tag = j_tags
        for chs, tgs in zip(chars, tag):
            assert len(chs) == len(tgs)
            p_line = ''
            r_line = ''
            for ch, tg in zip(chs, tgs):
                r_line += ' ' + ch
                if tg == 'I':
                    if ch == '.' or (ch >= '0' and ch <= '9'):
                        p_line += ch
                    else:
                        p_line += ' ' + ch
                elif tg == 'B':
                    p_line += '  ' + ch
                elif tg == 'T':
                    sent_seg = True
                    p_line += '  ' + ch + '<SENT>'
                    r_line += '<SENT>'
                elif tg == 'U':
                    sent_seg = True
                    p_line += ch + '<SENT>'
                    r_line += '<SENT>'
                elif len(ch.strip()) > 0:
                    p_line += '  ' + ch
            sub_out.append(p_line.strip())
            sub_raw_out.append(r_line.strip())
        out.append(sub_out)
        raw_out.append(sub_raw_out)
    if out[0][-1].endswith('<SENT>'):
        out[0][-1] = out[0][-1][:-len('<SENT>')]
    if raw_out[0][-1].endswith('<SENT>'):
        raw_out[0][-1] = raw_out[0][-1][:-len('<SENT>')]
    if sent_seg:
        out = split_sent(out[0], '<SENT>')
        raw_out = split_sent(raw_out[0], '<SENT>')
    return out, raw_out


def trim_output(out, length):
    assert len(out) == len(length)
    trimmed_out = []
    for item, l in zip(out, length):
        trimmed_out.append(item[:l])
    return trimmed_out


def generate_trans_out(x, idx2char):
    out = ''
    for i in x:
        if i == 3:
            out += ' '
        elif i in idx2char:
            out += idx2char[i]
    if '<#>' in out:
        out = out[:out.index('<#>')]
    out = out.replace('   ', ' ')
    out = out.replace('  ', ' ')
    return out


def generate_sent_out(raw, predictions):
    out = []
    line = ''
    assert len(raw) == len(predictions)
    for ch, tag in zip(raw, predictions):
        line += ch
        if tag == 1:
            line = line.strip()
            out.append(line)
            line = ''
    if len(line) > 0:
        line = line.strip()
        out.append(line)
    return out


def viterbi(max_scores, max_scores_pre, length, batch_size):
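    '''
    Recover the best tag path for each sequence in the batch by
    backtracking through the recorded argmax back-pointers
    (max_scores_pre), starting from the highest-scoring final node.
    '''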
    best_paths = []
    for m in range(batch_size):
        path = []
        last_max_node = np.argmax(max_scores[m][length[m] - 1])
        path.append(last_max_node)
        for t in range(1, length[m])[::-1]:
            last_max_node = max_scores_pre[m][t][last_max_node]
            path.append(last_max_node)
        path = path[::-1]
        best_paths.append(path)
    return best_paths


def get_new_chars(path, char2idx, is_space):
    new_chars = set()
    for line in codecs.open(path, 'rb', encoding='utf-8'):
        line = line.strip()
        if is_space == 'sea':
            line = pre_token(line)
        for ch in line:
            if ch not in char2idx:
                new_chars.add(ch)
    return new_chars


def get_valid_chars(chars, emb_path):
    valid_chars = []
    total = []
    for line in codecs.open(emb_path, 'rb', encoding='utf-8'):
        line = line.strip()
        sets = line.split(' ')
        total.append(sets[0])
    for ch in chars:
        if ch in total:
            valid_chars.append(ch)
    return valid_chars


def get_new_embeddings(new_chars, emb_dim, emb_path):
    assert os.path.isfile(emb_path)
    emb = {}
    new_emb = []
    for line in codecs.open(emb_path, 'rb', encoding='utf-8'):
        line = line.strip()
        sets = line.split(' ')
        emb[sets[0]] = np.asarray(sets[1:], dtype='float32')
    if '<UNK>' not in emb:
        unk = np.random.uniform(-math.sqrt(float(3) / emb_dim), math.sqrt(float(3) / emb_dim), emb_dim)
        emb['<UNK>'] = np.asarray(unk, dtype='float32')
    for ch in new_chars:
        if ch in emb:
            new_emb.append(emb[ch])
        else:
            new_emb.append(emb['<UNK>'])
    return new_emb


def update_char_dict(char2idx, new_chars, unk_chars_idx, valid_chars=None):
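    '''
    Extend char2idx with characters unseen at training time. New
    characters are indexed from len(char2idx) + 10 and marked as
    unknown unless they appear in valid_chars (i.e. have a pre-trained
    embedding). Unseen quotation-mark variants are redirected to a
    known left/right quote via sub_dict instead of being treated as
    unknown.
    '''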
    l_quos = ['"', '«', '„']
    r_quos = ['"', '»', '“']
    l_quos = [unicode(ch) for ch in l_quos]
    r_quos = [unicode(ch) for ch in r_quos]
    sub_dict = {}
    old_chars = char2idx.keys()
    dim = len(char2idx) + 10
    if valid_chars is not None:
        for ch in valid_chars:
            if char2idx[ch] in unk_chars_idx:
                unk_chars_idx.remove(char2idx[ch])
    for char in new_chars:
        if char not in char2idx and len(char.strip()) > 0:
            char2idx[char] = dim
            if valid_chars is None or char not in valid_chars:
                unk_chars_idx.append(dim)
            dim += 1
    idx2char = {k: v for v, k in char2idx.items()}
    for ch in new_chars:
        if ch in l_quos:
            for l_ch in l_quos:
                if l_ch in old_chars:
                    sub_dict[char2idx[ch]] = char2idx[l_ch]
                    if char2idx[ch] in unk_chars_idx:
                        unk_chars_idx.remove(char2idx[ch])
                    break
        elif ch in r_quos:
            for r_ch in r_quos:
                if r_ch in old_chars:
                    sub_dict[char2idx[ch]] = char2idx[r_ch]
                    if char2idx[ch] in unk_chars_idx:
                        unk_chars_idx.remove(char2idx[ch])
                    break
    return char2idx, idx2char, unk_chars_idx, sub_dict


def get_new_grams(path, gram2idx, is_raw=False, is_space=True):
    raw = []
    i = 0
    for line in codecs.open(path, 'rb', encoding='utf-8'):
        line = line.strip()
        if is_space == 'sea':
            line = pre_token(line)
        if i == 0 or is_raw:
            raw.append(line)
            i += 1
        if len(line) > 0:
            i += 1
        else:
            i = 0
    new_grams = []
    for g_dic in gram2idx:
        new_g = []
        # infer the n-gram order from a real key: dict order is arbitrary in
        # Python 2, so skip the special marker entries
        n = 0
        for k in g_dic.keys():
            if k in ('<P>', '<UNK>', '<#>') or '<PAD>' in k:
                continue
            if is_space == 'sea':
                n = len(k.split('_'))
            else:
                n = len(k)
            break
        grams = ngrams(raw, n, is_space)
        for g in grams:
            if g not in g_dic:
                new_g.append(g)
        new_grams.append(new_g)
    return new_grams


def printer(raw, tagged, multi_out, outpath, sent_seg, form='conll'):
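    '''
    Write the tagged output to outpath, either as plain tokenized lines
    or in a CoNLL-style format where '!#!'-annotated items become
    multi-word token range lines (id-id) followed by their expansions.
    '''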
    assert len(tagged) == len(multi_out)
    validator(raw, multi_out)
    wt = codecs.open(outpath, 'w', encoding='utf-8')
    if form == 'conll':
        if not sent_seg:
            for raw_t, tagged_t, multi_t in zip(raw, tagged, multi_out):
                if len(multi_t) > 0:
                    wt.write('#sent_raw: ' + raw_t + '\n')
                    wt.write('#sent_tok: ' + tagged_t + '\n')
                    idx = 1
                    tgs = multi_t.split('  ')
                    pl = ''
                    for _ in range(8):
                        pl += '\t' + '_'
                    for tg in tgs:
                        if '!#!' in tg:
                            segs = tg.split('!#!')
                            wt.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' + segs[0] + pl + '\n')
                        else:
                            wt.write(str(idx) + '\t' + tg + pl + '\n')
                            idx += 1
                    wt.write('\n')
        else:
            for tagged_t, multi_t in zip(tagged, multi_out):
                if len(tagged_t.strip()) > 0:
                    wt.write('#sent_tok: '+ tagged_t + '\n')
                    idx = 1
                    tgs = multi_t.split('  ')
                    pl = ''
                    for _ in range(8):
                        pl += '\t' + '_'
                    for tg in tgs:
                        if '!#!' in tg:
                            segs = tg.split('!#!')
                            wt.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' + segs[0] + pl + '\n')
                        else:
                            wt.write(str(idx) + '\t' + tg + pl + '\n')
                            idx += 1
                    wt.write('\n')
    else:
        for tg in tagged:
            wt.write(tg + '\n')
    wt.close()


def biased_out(prediction, bias):
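    '''
    Binarize two-class scores with an adjustable threshold: the score
    differences pre[:, 0] - pre[:, 1] are pooled and sorted, the value
    at the `bias` quantile becomes the threshold, and roughly a `bias`
    fraction of positions (those at or above it) are labelled 0, the
    rest 1.
    '''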
    out = []
    b_pres = []
    for pre in prediction:
        b_pres.append(pre[:, 0] - pre[:, 1])
    props = np.concatenate(b_pres)
    props = np.sort(props)[::-1]
    idx = int(bias*len(props))
    if idx == len(props):
        idx -= 1
    th = props[idx]
    print 'threshold: ', th, 1 / (1 + np.exp(-th))
    for pre in b_pres:
        pre[pre >= th] = 0
        pre[pre != 0] = 1
        out.append(pre)
    return out


def to_one_hot(y, nb_classes=None):
    '''Convert class vector (integers from 0 to nb_classes) to binary class matrix, for use with categorical_crossentropy.
    # Arguments
        y: class vector to be converted into a matrix
        nb_classes: total number of classes
    # Returns
        A binary matrix representation of the input.
    '''
    if not nb_classes:
        nb_classes = np.max(y)+1
    Y = np.zeros((len(y), nb_classes))
    for i in range(len(y)):
        Y[i, y[i]] = 1.
    return Y


def validator(raw, generated):
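    '''
    Sanity check: the generated tokens, with all whitespace removed,
    must concatenate back to exactly the raw input text; otherwise an
    exception is raised. Tokens annotated with '!#!' are multi-word
    tokens whose expansions are skipped during the comparison.
    '''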
    raw_l = ''.join(raw)
    raw_l = ''.join(raw_l.split())
    for g in generated:
        g_tokens = g.split('  ')
        j = 0
        while j < len(g_tokens):
            if '!#!' in g_tokens[j]:
                segs = g_tokens[j].split('!#!')
                c_t = int(segs[1])
                r_seg = ''.join(segs[0].split())
                l_w = len(r_seg)
                if r_seg == raw_l[:l_w]:
                    raw_l = raw_l[l_w:]
                    raw_l = raw_l.strip()
                else:
                    raise Exception('Error: unmatched...')
                j += c_t
            else:
                r_seg = ''.join(g_tokens[j].split())
                l_w = len(r_seg)
                if r_seg == raw_l[:l_w]:
                    raw_l = raw_l[l_w:]
                    raw_l = raw_l.strip()
                else:
                    print r_seg
                    print raw_l[:l_w]
                    print ''
                    raise Exception('Error: unmatched...')
            j += 1


def mlp_post(raw, prediction, is_space=False, form='mlp1'):
    assert len(raw) == len(prediction)
    out = []
    for r_l, p_l in zip(raw, prediction):
        st = ''
        rtokens = r_l.split()
        ptokens = p_l.split('  ')
        purged = []
        for pt in ptokens:
            purged.append(pt.strip())
        ptokens = purged
        ptokens_str = ''.join(ptokens)
        assert ''.join(rtokens) == ''.join(ptokens_str.split())
        if form == 'mlp1':
            if is_space == 'sea':
                for p_t in ptokens:
                    st += p_t.replace(' ', '_') + ' '
            else:
                while rtokens and ptokens:
                    if rtokens[0] == ptokens[0]:
                        st += ptokens[0] + ' '
                        rtokens.pop(0)
                        ptokens.pop(0)
                    else:
                        if len(rtokens[0]) <= len(ptokens[0]):
                            assert ptokens[0][:len(rtokens[0])] == rtokens[0]
                            st += rtokens[0] + ' '
                            ptokens[0] = ptokens[0][len(rtokens[0]):].strip()
                            rtokens.pop(0)
                        else:
                            can = ''
                            while can != rtokens[0] and ptokens:
                                can += ptokens[0]
                                st += ptokens[0] + '\\\\'
                                ptokens.pop(0)
                            st = st[:-2] + ' '
                            rtokens.pop(0)
        else:
            for p_t in ptokens:
                st += p_t + ' '
        out.append(st.strip())
    return out