# -*- coding: utf-8 -*-

"""
cherry.base
~~~~~~~~~~~~
Base method for cherry
:copyright: (c) 2018-2020 by Windson Yang
:license: MIT License, see LICENSE for more details.
"""

import os
import pickle
import tarfile
import hashlib
import codecs
import urllib
import logging
import numpy as np

from collections import namedtuple
from urllib.request import urlretrieve
from .exceptions import *
from .common import *
from sklearn.feature_extraction._stop_words import ENGLISH_STOP_WORDS
from sklearn.feature_extraction.text import CountVectorizer, \
    TfidfVectorizer, HashingVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import SGDClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.datasets import load_files

CHERRY_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'cherry')
# Datasets live in a `datasets` folder under the current working directory.
DATA_DIR = os.path.join(os.getcwd(), 'datasets')

__all__ = ['DATA_DIR',
           'get_stop_words',
           'load_data',
           'write_file',
           'load_all',
           'load_cache',
           'get_vectorizer_and_clf',
           'get_tokenizer',
           'get_vectorizer',
           'get_clf']


def get_stop_words(language='English'):
    '''
    Return the stop word collection for `language`.

    'English' maps to scikit-learn's ENGLISH_STOP_WORDS; any other
    language is looked up in STOP_WORDS (presumably provided by the
    wildcard import from `.common` -- confirm).

    There are several known issues in the provided English stop word
    list. It does not aim to be a general, one-size-fits-all solution,
    as some tasks may require a more custom solution.
    See https://aclweb.org/anthology/W18-2502 for more details.

    Raises NotSupportError when `language` has no stop word list.

    TODO: add IDF after every stop word.
    '''
    if language == 'English':
        return ENGLISH_STOP_WORDS
    try:
        return STOP_WORDS[language]
    except KeyError:
        error = 'Cherry didn\'t support {0} at this moment.'.format(language)
        raise NotSupportError(error)


def _is_missing(data):
    '''
    Tell whether `data` should be treated as "not supplied by the caller".

    None and empty sized containers count as missing. Using len() instead
    of plain truthiness keeps numpy arrays working: bool() on an array with
    more than one element raises ValueError, and bool(np.array([0])) is
    False even though data was supplied.
    '''
    if data is None:
        return True
    try:
        return len(data) == 0
    except TypeError:
        # Unsized objects (e.g. generators) fall back to truthiness.
        return not data


def load_all(model, language=None, preprocessing=None, categories=None,
             encoding=None, vectorizer=None, vectorizer_method=None,
             clf=None, clf_method=None, x_data=None, y_data=None):
    '''
    Collect the training data, vectorizer and classifier for `model`.

    DATA_DIR is created when it does not exist. When either `x_data` or
    `y_data` is missing, the data is loaded through `load_data` and
    `preprocessing` (if given) is applied to every loaded text. Data passed
    in by the caller is returned untouched.

    A missing `vectorizer` / `clf` is built from `vectorizer_method` /
    `clf_method` (see `get_vectorizer_and_clf`).

    Returns the tuple (x_data, y_data, vectorizer, clf).
    Raises FilesNotFoundError when the data for `model` cannot be found.
    '''
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(DATA_DIR, exist_ok=True)
    # If user didn't pass x_data and y_data, try to load data from local or remote
    if _is_missing(x_data) or _is_missing(y_data):
        try:
            cache = load_data(model, categories=categories, encoding=encoding)
        except FilesNotFoundError as e:
            error = ('Please make sure your put the {0} data inside `datasets` '
                     'folder or use model inside "email", "review" or "newsgroups".'.format(model))
            raise FilesNotFoundError(error) from e
        if preprocessing:
            cache.data = [preprocessing(text) for text in cache.data]
        x_data, y_data = cache.data, cache.target
    vectorizer, clf = get_vectorizer_and_clf(
        language, vectorizer, clf, vectorizer_method, clf_method)
    return x_data, y_data, vectorizer, clf


def load_data(model, categories=None, encoding=None):
    '''
    Load data using `model` name

    Uses the local copy when DATA_DIR/<model> exists, otherwise tries to
    download the data set.
    '''
    model_path = os.path.join(DATA_DIR, model)
    if os.path.exists(model_path):
        return _load_data_from_local(
            model, categories=categories, encoding=encoding)
    else:
        return _load_data_from_remote(
            model, categories=categories, encoding=encoding)


def _read_compressed_pickle(path):
    '''
    Read the file at `path`, zlib-decompress it and return the unpickled
    object.
    '''
    with open(path, 'rb') as f:
        compressed_content = f.read()
    uncompressed_content = codecs.decode(compressed_content, 'zlib_codec')
    # NOTE(security): pickle.loads can execute arbitrary code. Only load
    # cache files that were written by this library.
    return pickle.loads(uncompressed_content)


def _write_compressed_pickle(path, content):
    '''
    Pickle `content`, zlib-compress it and write the result to `path`.
    '''
    # Serialize before opening the file so that a pickling failure does not
    # leave an empty or truncated file behind.
    compressed_content = codecs.encode(pickle.dumps(content), 'zlib_codec')
    with open(path, 'wb') as f:
        f.write(compressed_content)


def _load_data_from_local(
        model, categories=None, encoding=None):
    '''
    1. Find local cache files
    2. If we can't find the cache files
        2.1 Try to create cache files using data files inside `datasets`.
        2.2 Raise error if create cache files failed.

    Raises NotSupportError when an existing cache file cannot be read.

    NOTE(review): an existing cache file is returned as-is, so
    `categories` and `encoding` only take effect when the cache is
    first built.
    '''
    model_path = os.path.join(DATA_DIR, model)
    cache_path = os.path.join(model_path, model + '.pkz')
    if os.path.exists(cache_path):
        try:
            return _read_compressed_pickle(cache_path)['all']
        except Exception as e:
            # Can't load cache files
            error = ('Can\'t load cached data from {0}. '
                     'Please try again after delete cache files.'.format(model))
            raise NotSupportError(error) from e
    cache = dict(all=load_files(
        model_path, categories=categories, encoding=encoding))
    _write_compressed_pickle(cache_path, cache)
    return cache['all']


def _load_data_from_remote(model, categories=None, encoding=None):
    '''
    Download, unpack and load one of the built-in data sets.

    The download information is read from BUILD_IN_MODELS[model]
    (presumably provided by the wildcard import from `.common` -- confirm)
    as (filename, url, checksum, encoding). The encoding stored there is
    used for loading; the `encoding` argument is not consulted.

    Raises FilesNotFoundError when `model` is not a built-in model.
    '''
    try:
        info = BUILD_IN_MODELS[model]
    except KeyError as e:
        error = ('{0} is not in BUILD_IN_MODELS.').format(model)
        raise FilesNotFoundError(error) from e
    # The original data can be found at:
    # https://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz
    meta_data_c = namedtuple('meta_data_c', ['filename', 'url', 'checksum', 'encoding'])
    # Create a namedtuple
    meta_data = meta_data_c(filename=info[0], url=info[1], checksum=info[2], encoding=info[3])
    _fetch_remote(meta_data, DATA_DIR)
    _decompress_data(meta_data.filename, DATA_DIR)
    return _load_data_from_local(
        model, categories=categories, encoding=info[3])


def _fetch_remote(remote, dirname=None):
    """
    Function from sklearn
    Helper function to download a remote datasets into path
    Copy from sklearn.datasets.base

    `remote` must expose `filename`, `url` and `checksum`. The file is
    stored under `dirname` (or at `remote.filename` when `dirname` is None)
    and its path is returned.

    Raises IOError when the SHA256 of the download differs from
    `remote.checksum`.
    """
    file_path = (remote.filename if dirname is None
                 else os.path.join(dirname, remote.filename))
    print('Downloading data from {0}'.format(remote.url))
    urlretrieve(remote.url, file_path)
    checksum = _sha256(file_path)
    if remote.checksum != checksum:
        raise IOError("{} has an SHA256 checksum ({}) "
                      "differing from expected ({}), "
                      "file may be corrupted.".format(file_path, checksum,
                                                      remote.checksum))
    return file_path


def _sha256(path):
    """
    Function from sklearn
    Calculate the sha256 hash of the file at path.

    The file is read in fixed-size chunks; the hex digest is returned.
    """
    sha256hash = hashlib.sha256()
    chunk_size = 8192
    with open(path, "rb") as f:
        # iter() with a sentinel stops at the first empty read (end of file).
        for buffer in iter(lambda: f.read(chunk_size), b""):
            sha256hash.update(buffer)
    return sha256hash.hexdigest()


def _decompress_data(filename, model_path):
    '''
    Function from sklearn

    Extract the gzip-compressed tar archive `model_path`/`filename` into
    `model_path`, then delete the archive.
    '''
    file_path = os.path.join(model_path, filename)
    logging.debug("Decompressing %s", file_path)
    # The context manager closes the archive handle even if extraction fails.
    with tarfile.open(file_path, "r:gz") as archive:
        # NOTE(security): extractall() trusts the member paths inside the
        # archive. Callers in this module verify the SHA256 checksum first;
        # do not use this on archives from untrusted sources.
        archive.extractall(path=model_path)
    os.remove(file_path)


def _train_test_split(cache, test_size=0.1):
    '''
    Split the data stored under cache['all'] into train and test parts.

    `data`, `target` and `filenames` of the stored object are normalised
    in place (list / numpy array / numpy array) before splitting with a
    fixed random_state of 0. Returns whatever
    sklearn.model_selection.train_test_split returns for (data, target).
    '''
    # Imported locally: no import of train_test_split is visible at the
    # top of this module.
    from sklearn.model_selection import train_test_split

    data = cache['all']
    data.data = list(data.data)
    data.target = np.array(list(data.target))
    data.filenames = np.array(list(data.filenames))
    return train_test_split(data.data, data.target,
                            test_size=test_size, random_state=0)


def write_file(path, data):
    '''
    Write data to path

    The file is opened in 'a+' mode, so `data` is appended to existing
    content.
    '''
    with open(path, 'a+') as f:
        f.write(data)


def write_cache(model, content, path):
    '''
    Write cached file under model dir

    `content` is pickled, zlib-compressed and stored at
    DATA_DIR/<model>/<path>.
    '''
    cache_path = os.path.join(DATA_DIR, model, path)
    _write_compressed_pickle(cache_path, content)


def load_cache(model, path):
    '''
    Load cache data from file

    Reads DATA_DIR/<model>/<path> and returns the unpickled content.

    Raises CacheNotFoundError when the file does not exist or cannot be
    decoded.
    '''
    cache_path = os.path.join(DATA_DIR, model, path)
    if not os.path.exists(cache_path):
        error = (
            'Can\'t find cache files')
        raise CacheNotFoundError(error)
    try:
        return _read_compressed_pickle(cache_path)
    except Exception as e:
        error = (
            'Can\'t load cached files.')
        raise CacheNotFoundError(error) from e


def english_tokenizer_wrapper(text):
    '''
    Tokenize `text` with nltk's word_tokenize, dropping tokens of length 1
    or less.
    '''
    # Imported lazily so nltk is only required when this tokenizer is used.
    from nltk.tokenize import word_tokenize
    return [t for t in word_tokenize(text) if len(t) > 1]


def chinese_tokenizer_wrapper(text):
    '''
    Tokenize `text` with jieba.cut, dropping tokens of length 1 or less.
    '''
    # Imported lazily so jieba is only required when this tokenizer is used.
    import jieba
    return [t for t in jieba.cut(text) if len(t) > 1]


def get_tokenizer(language):
    '''
    Return the tokenizer function for 'English' or 'Chinese'.

    Raises NotSupportError for any other language.
    '''
    if language == 'English':
        return english_tokenizer_wrapper
    elif language == 'Chinese':
        return chinese_tokenizer_wrapper
    else:
        raise NotSupportError((
            'You need to specify tokenizer function ' +
            'when the language is nor English or Chinese.'))


def get_vectorizer_and_clf(
        language, vectorizer, clf,
        vectorizer_method, clf_method):
    '''
    Return (vectorizer, clf), building whichever one the caller did not
    supply from `vectorizer_method` / `clf_method`.
    '''
    if not vectorizer:
        vectorizer = get_vectorizer(language, vectorizer_method)
    if not clf:
        clf = get_clf(clf_method)
    return vectorizer, clf


def get_vectorizer(language, vectorizer_method):
    '''
    Build the vectorizer named by `vectorizer_method` ("Count", "Tfidf" or
    "Hashing"), configured with the tokenizer and stop words for `language`.

    Raises MethodNotFoundError for an unknown `vectorizer_method`.
    '''
    mapping = {
        'Count': CountVectorizer,
        'Tfidf': TfidfVectorizer,
        'Hashing': HashingVectorizer,
    }
    try:
        method = mapping[vectorizer_method]
    except KeyError:
        error = 'Please make sure vectorizer_method in "Count", "Tfidf" or "Hashing".'
        raise MethodNotFoundError(error)
    else:
        return method(tokenizer=get_tokenizer(language),
                      stop_words=get_stop_words(language))


def get_clf(clf_method):
    '''
    Build the classifier named by `clf_method` ("MNB", "SGD",
    "RandomForest" or "AdaBoost") with this module's preset parameters.

    Raises MethodNotFoundError for an unknown `clf_method`.
    '''
    mapping = {
        'MNB': (MultinomialNB, {'alpha': 0.1}),
        'SGD': (SGDClassifier, {'loss': 'hinge', 'penalty': 'l2',
                                'alpha': 1e-3, 'max_iter': 5, 'tol': None}),
        'RandomForest': (RandomForestClassifier, {'max_depth': 5}),
        'AdaBoost': (AdaBoostClassifier, {}),
    }
    try:
        method, parameters = mapping[clf_method]
    except KeyError:
        error = 'Please make sure clf_method in "MNB", "SGD", "RandomForest" or "AdaBoost".'
        raise MethodNotFoundError(error)
    return method(**parameters)