#!/usr/bin/python
# **************************
# * Author : baiyyang
# * Email : baiyyang@163.com
# * Description :
# * create time : 2018/3/13 5:05 PM
# * file name : data_helpers.py

import collections
import re
import string

import jieba
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from zhon.hanzi import punctuation


def clean_str(s):
    """
    Tokenization/string cleaning for all datasets except for SST.
    :param s: raw sentence
    :return: cleaned, lower-cased sentence
    """
    s = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", s)
    s = re.sub(r"\'s", " 's", s)
    s = re.sub(r"\'ve", " 've", s)
    s = re.sub(r"n\'t", " n't", s)
    s = re.sub(r"\'re", " 're", s)
    s = re.sub(r"\'d", " 'd", s)
    s = re.sub(r"\'ll", " 'll", s)
    s = re.sub(r",", " , ", s)
    s = re.sub(r"!", " ! ", s)
    # Escape the parens/question mark in the pattern only; the replacements
    # are plain text (a backslash in the replacement would be inserted literally).
    s = re.sub(r"\(", " ( ", s)
    s = re.sub(r"\)", " ) ", s)
    s = re.sub(r"\?", " ? ", s)
    s = re.sub(r"\s{2,}", " ", s)
    return s.strip().lower()


def load_data_and_labels(positive_data_file, negative_data_file):
    """
    Loads MR polarity data from files, splits the data into words and generates labels.
    Returns split sentences and labels.
    :param positive_data_file: path to the positive examples, one per line
    :param negative_data_file: path to the negative examples, one per line
    :return: [x_text, y]
    """
    # Load data from files
    with open(positive_data_file, 'r', encoding='utf-8') as f:
        positive_examples = [s.strip() for s in f]
    with open(negative_data_file, 'r', encoding='utf-8') as f:
        negative_examples = [s.strip() for s in f]
    # Split by words
    x_text = positive_examples + negative_examples
    x_text = [clean_str(sent) for sent in x_text]
    # Generate labels: [0, 1] for positive, [1, 0] for negative
    positive_labels = [[0, 1] for _ in positive_examples]
    negative_labels = [[1, 0] for _ in negative_examples]
    y = np.concatenate([positive_labels, negative_labels], 0)
    return [x_text, y]
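
# Minimal sanity check for clean_str (the input string is illustrative, not
# taken from any dataset):
#
#     >>> clean_str("It's a test!")
#     "it 's a test !"
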
def load_data_and_labels_chinese(train_data_file, test_data_file):
    """
    Load the Chinese medical disease-classification dataset.
    :param train_data_file: tab-separated file, one "text\\tlabel" pair per line
    :param test_data_file: tab-separated file, one "text\\tlabel" pair per line
    :return: word2id, train_datas, train_labels, test_datas, test_labels
    """
    words = []
    contents = []
    train_datas = []
    test_datas = []
    labels = []
    # Build the training set
    with open(train_data_file, 'r', encoding='utf-8') as f:
        for line in f:
            data, label = line.strip().split('\t')
            labels.append(label)
            # Word segmentation; drop Chinese and ASCII punctuation
            segments = [seg for seg in jieba.cut(data, cut_all=False)]
            segments_ = [seg.strip() for seg in segments
                         if seg not in punctuation and seg not in string.punctuation]
            contents.append([seg_ for seg_ in segments_ if seg_ != ''])
            words.extend(segments_)
    words = [word for word in words if word != '']
    # Keep the 9,999 most frequent words; everything else maps to 'UNK'
    count = [['UNK', -1]]
    count.extend(collections.Counter(words).most_common(9999))
    word2id = {}
    for word, _ in count:
        word2id[word] = len(word2id)
    # id2word = dict(zip(word2id.values(), word2id.keys()))
    print('dictionary_size:', len(word2id))
    sentence_max_length = max([len(content) for content in contents])
    print('sentence_max_length:', sentence_max_length)
    # Map words to ids and pad every sentence to sentence_max_length.
    # Note that id 0 does double duty as both 'UNK' and padding.
    for content in contents:
        train_data = [word2id[word] if word in word2id else word2id['UNK'] for word in content]
        train_data.extend([0] * (sentence_max_length - len(train_data)))
        train_datas.append(train_data)
    label_encoder = LabelEncoder()
    integer_encoded = label_encoder.fit_transform(np.array(labels))
    # scikit-learn >= 1.2 renames the `sparse` argument to `sparse_output`
    onehot_encoder = OneHotEncoder(sparse=False)
    train_labels = onehot_encoder.fit_transform(integer_encoded.reshape(len(integer_encoded), 1))
    print(train_labels.shape)

    # Build the test set
    labels = []
    contents = []
    with open(test_data_file, 'r', encoding='utf-8') as f:
        for line in f:
            data, label = line.strip().split('\t')
            labels.append(label)
            # Word segmentation; drop Chinese and ASCII punctuation
            segments = [segment for segment in jieba.cut(data, cut_all=False)]
            segments_ = [segment.strip() for segment in segments
                         if segment not in punctuation and segment not in string.punctuation]
            contents.append([seg_ for seg_ in segments_ if seg_ != ''])
    # Truncate or pad test sentences to the training set's maximum length
    for content in contents:
        test_data = [word2id[word] if word in word2id else word2id['UNK'] for word in content]
        if sentence_max_length > len(test_data):
            test_data.extend([0] * (sentence_max_length - len(test_data)))
        else:
            test_data = test_data[:sentence_max_length]
        test_datas.append(test_data)
    # Reuse the encoders fit on the training labels so the test one-hot columns
    # line up with the training ones (refitting could permute the classes).
    integer_encoded = label_encoder.transform(np.array(labels))
    test_labels = onehot_encoder.transform(integer_encoded.reshape(len(integer_encoded), 1))
    return word2id, train_datas, train_labels, test_datas, test_labels


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """
    Generates a batch iterator for a dataset.
    :param data: list (or array) of examples
    :param batch_size: number of examples per batch
    :param num_epochs: number of passes over the data
    :param shuffle: whether to reshuffle the data at each epoch
    :return: yields one batch at a time
    """
    data = np.array(data)
    data_size = len(data)
    num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
    for _ in range(num_epochs):
        # Shuffle the data at each epoch
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index: end_index]


if __name__ == '__main__':
    word2id, train_datas, train_labels, test_datas, test_labels = \
        load_data_and_labels_chinese('data/train.txt', 'data/test.txt')
    print(len(train_datas), len(train_labels), len(test_datas), len(test_labels))
    print(train_datas[:5])
    print(train_labels[:5])
    print(test_datas[:5])
    print(test_labels[:5])
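
# Sketch of how batch_iter is typically driven from a training loop (the
# batch_size and num_epochs values below are illustrative, not prescribed by
# this module):
#
#     batches = batch_iter(list(zip(train_datas, train_labels)), batch_size=64, num_epochs=10)
#     for batch in batches:
#         x_batch, y_batch = zip(*batch)
#         # feed x_batch / y_batch to the model's training step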