#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 19 21:27:21 2017

@author: quantumliu
"""
import re,pickle
import jieba
import numpy as np
import h5py
def pad_sequences(sequences, maxlen=None, dtype='int32',
                  padding='pre', truncating='pre', value=0.):
    """Stack variable-length sequences into one array of uniform length.

    Args:
        sequences: List of sequences (lists or arrays). Entries may also be
            sequences of fixed-shape samples.
        maxlen: Target length. When ``None`` the longest sequence is used.
        dtype: dtype of the returned array.
        padding: ``'pre'`` right-aligns each sequence (fill goes in front),
            ``'post'`` left-aligns it (fill goes behind).
        truncating: ``'pre'`` keeps the last ``maxlen`` items of a sequence
            that is too long, ``'post'`` keeps the first ``maxlen`` items.
        value: Fill value used for the padded positions.

    Returns:
        Array of shape ``(len(sequences), maxlen) + sample_shape``.

    Raises:
        ValueError: On an unknown ``padding`` / ``truncating`` mode, or when
            a sequence's per-item shape differs from the first non-empty one.
    """
    seq_lengths = [len(seq) for seq in sequences]
    n_rows = len(sequences)
    if maxlen is None:
        maxlen = np.max(seq_lengths)

    # The per-item shape is taken from the first non-empty sequence; it
    # stays empty when every sequence is empty.
    feature_shape = next(
        (np.asarray(seq).shape[1:] for seq in sequences if len(seq) > 0),
        tuple())

    padded = (np.ones((n_rows, maxlen) + feature_shape) * value).astype(dtype)
    for row, seq in enumerate(sequences):
        if len(seq) == 0:
            # Nothing to copy: the row keeps the fill value.
            continue

        if truncating == 'pre':
            kept = seq[-maxlen:]
        elif truncating == 'post':
            kept = seq[:maxlen]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # Every sequence must share the per-item shape probed above.
        kept = np.asarray(kept, dtype=dtype)
        if kept.shape[1:] != feature_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (kept.shape[1:], row, feature_shape))

        if padding == 'post':
            padded[row, :len(kept)] = kept
        elif padding == 'pre':
            padded[row, -len(kept):] = kept
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return padded
def num_sub(text):
    """Replace dates, clock times and remaining numbers by placeholder tokens.

    Three passes run from most to least specific so that the digits inside a
    date or a time are not consumed by the generic number pattern:
    dates become ``DATE``, times become ``TIME`` and every other number
    becomes ``FLOAT``.

    Args:
        text: Input string.

    Returns:
        ``text`` with every match replaced by its placeholder.
    """
    # Numeric dates: the separator is a literal '-', '/' or '.' and the same
    # one has to appear twice (2017-06-19, 17/6/19, 2017.06.19).
    r_date = (r'(?:\d{4}|\d{2})(?P<sep>[-/.])\d{1,2}(?P=sep)\d{1,2}'
              r'|\d{4}年\d{1,2}月\d{1,2}日'
              r'|\d{1,2}月\d{1,2}日'
              r'|\d{1,2}日')
    # Clock times: H:MM with optional :SS, or an hour (0-24) followed by
    # "时" and optional minutes (0-59) followed by "分".
    r_time = (r'(?:[01]?[0-9]|2[0-3]):[0-5]?[0-9](?::[0-5]?[0-9])?'
              r'|(?:[01]?[0-9]|2[0-4])时(?:[0-5]?[0-9]分)?')
    r_num = r'[-+]?[0-9]*\.?[0-9]+'

    text = re.sub(r_date, 'DATE', text)
    text = re.sub(r_time, 'TIME', text)
    return re.sub(r_num, 'FLOAT', text)
def get_text(news_list,key):
    """Concatenate one field of every news record into a single string.

    For each record the value stored under ``key`` (empty string when the key
    is absent) has all whitespace removed and is terminated with the marker
    ``EOS`` plus a newline, so the result holds one line per record.

    Args:
        news_list: Iterable of dict-like records exposing ``.get``.
        key: Name of the field to extract.

    Returns:
        The joined text; an empty string for an empty ``news_list``.
    """
    whitespace = re.compile(r'\s')
    lines = []
    for record in news_list:
        raw = record.get(key, '')
        lines.append(whitespace.sub('', raw) + 'EOS\n')
    return ''.join(lines)
def cut(text, custom_words=('FLOAT', 'TIME', 'DATE', 'EOS'), processes=32):
    """Tokenize ``text`` with jieba after registering placeholder tokens.

    Args:
        text: String to segment.
        custom_words: Words added to jieba's dictionary before segmenting
            (by default the placeholder tokens used by this script).
        processes: Worker count passed to ``jieba.enable_parallel``. A falsy
            value (``0`` / ``None``) skips parallel mode altogether, which is
            useful where jieba's parallel mode is unavailable.

    Returns:
        List of tokens as returned by ``jieba.lcut``.
    """
    # Register the words before any worker pool is started so that workers
    # created afterwards start from a dictionary that already contains them.
    for word in custom_words:
        jieba.add_word(word)

    # Enable parallel mode at most once: this function is called repeatedly
    # and each enable_parallel call would otherwise set up another pool.
    # NOTE(review): relies on jieba exposing its pool as ``jieba.pool``; if
    # the attribute is missing, getattr falls back to enabling on every call.
    if processes and getattr(jieba, 'pool', None) is None:
        jieba.enable_parallel(processes)

    # NOTE(review): confirm that jieba.lcut honours parallel mode in the
    # installed jieba version; it may only affect jieba.cut.
    return jieba.lcut(text)
def get_dic(ulist, start=0):
    """Build a token -> integer id mapping from a list of tokens.

    Distinct tokens are collected and sorted by ``np.unique`` and numbered
    consecutively in that order.

    Args:
        ulist: Sequence of tokens; duplicates are allowed.
        start: Id assigned to the first token. The default ``0`` keeps the
            historical numbering; pass ``1`` to keep id 0 free for callers
            that use 0 as the padding / unknown-token id.

    Returns:
        Dict mapping each distinct token to its id. Keys are the NumPy
        scalars produced by iterating the unique array.
    """
    vocab = np.unique(np.array(ulist))
    return {token: idx for idx, token in enumerate(vocab, start)}
def get_inverse(words, udic, sp_ch='\n'):
    """Convert a flat token list into a padded matrix of token ids.

    The token list is split into sentences at every token equal to ``sp_ch``.
    Each remaining token is looked up in ``udic`` (tokens missing from the
    dictionary become 0) and the sentences are padded at the end, with 0, to
    a common length of ``mean + std`` of the sentence lengths (at least 1).
    Longer sentences are truncated by ``pad_sequences`` using its default
    truncating mode.

    Args:
        words: Flat list of tokens containing ``sp_ch`` separator tokens.
        udic: Mapping from token to integer id.
        sp_ch: Token that marks the end of a sentence.

    Returns:
        2-D integer array with one row per sentence.
    """
    # Group the tokens directly instead of joining them into a string and
    # splitting it again, so a token is never altered by its own content.
    sentences = [[]]
    for token in words:
        if token == sp_ch:
            sentences.append([])
        else:
            sentences[-1].append(token)

    # A separator at the very end (or very start) delimits nothing; drop the
    # empty sentence it produced. Empty sentences in the middle are kept so
    # the row order stays aligned with the input lines.
    if len(sentences) > 1 and not sentences[-1]:
        sentences.pop()
    if len(sentences) > 1 and not sentences[0]:
        sentences.pop(0)

    inverse = [[udic.get(token, 0) for token in sentence]
               for sentence in sentences]
    lens = [len(ids) for ids in inverse]
    # Keep at least one column so the result is never zero-width.
    maxlen = max(1, int(np.std(lens) + np.mean(lens)))
    return pad_sequences(inverse, maxlen=maxlen, padding='post')
if __name__ == '__main__':
    # Fraction of the samples held out for validation.
    val_split = 0.2

    # NOTE: pickle.load can execute arbitrary code; only load a
    # 'sina_news.pkl' that comes from a trusted source.
    with open('sina_news.pkl', 'rb') as f:
        news = pickle.load(f)

    # Tokenize titles (labels) and abstracts (inputs) after replacing
    # dates / times / numbers with placeholder tokens.
    ct = cut(num_sub(get_text(news, 'title')))
    ca = cut(num_sub(get_text(news, 'abstract')))

    # One shared vocabulary over abstracts and titles.
    uwords = ca + ct
    udic = get_dic(uwords)
    print('There are ',len(uwords),' words.\n',len(udic),' unique tokens.')

    udata = get_inverse(ca, udic)
    ulabel = get_inverse(ct, udic)
    nb_samples = udata.shape[0]
    # Both matrices are indexed with the same permutation below, so a row
    # count mismatch would pair abstracts with the wrong titles.
    if ulabel.shape[0] != nb_samples:
        raise ValueError('Abstract/title row count mismatch: %d vs %d' %
                         (nb_samples, ulabel.shape[0]))
    print('There are '+str(nb_samples)+' samples')

    # Random train / validation split.
    perm = np.random.permutation(nb_samples)
    nb_train = int(np.floor(nb_samples * (1 - val_split)))
    train_idx, val_idx = perm[:nb_train], perm[nb_train:]
    train_data, train_label = udata[train_idx], ulabel[train_idx]
    val_data, val_label = udata[val_idx], ulabel[val_idx]

    # Persist the vocabulary and both splits.
    with open('dic.pkl', 'wb') as f:
        pickle.dump(udic, f)
    with h5py.File('data_train.h5', 'w') as f:
        f.create_dataset('y', data=train_label)
        f.create_dataset('x', data=train_data)
    with h5py.File('data_val.h5', 'w') as f:
        f.create_dataset('y', data=val_label)
        f.create_dataset('x', data=val_data)
    print('There are ',train_data.shape[0],' samples for training.')
    print('There are ',val_data.shape[0],' samples for validation.')