# Author: Mathurin Massias <mathurin.massias@gmail.com>
# License: BSD 3 clause

import os
from os.path import join as pjoin
from pathlib import Path
from bz2 import BZ2Decompressor

import numpy as np
from scipy import sparse
from download import download
from sklearn import preprocessing
from sklearn.datasets import load_svmlight_file


# Local cache directory where archives and converted arrays are stored.
CELER_PATH = pjoin(str(Path.home()), 'celer_data')

# Dataset name -> path of the dataset relative to the LIBSVM datasets URL.
# The first path component ('binary', 'regression', 'multilabel') is used by
# load_libsvm to infer the kind of task.
NAMES = {'rcv1_train': 'binary/rcv1_train.binary',
         'news20': 'binary/news20.binary',
         'finance': 'regression/log1p.E2006.train',
         'kdda_train': 'binary/kdda',
         'rcv1_topics_test': 'multilabel/rcv1_topics_test_2.svm',
         'url': 'binary/url_combined',
         'webspam': 'binary/webspam_wc_normalized_trigram.svm'}

# Dataset name -> number of features handed to the svmlight parser.
N_FEATURES = {'finance': 4272227,
              'news20': 1355191,
              'rcv1_train': 47236,
              'kdda_train': 20216830,
              'rcv1_topics_test': 47236,
              'url': 3231961,
              'webspam': 16609143}


def download_libsvm(dataset, destination, replace=False):
    """Download a dataset from LIBSVM website.

    Parameters
    ----------
    dataset : string
        Dataset name. Must be a key of NAMES.

    destination : string
        Path handed to ``download`` as the location of the bz2 archive.

    replace : bool, default=False
        Forwarded to ``download``.

    Returns
    -------
    path
        The value returned by ``download``.
    """
    url = ("https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/" +
           NAMES[dataset])
    path = download(url + '.bz2', destination, replace=replace)
    return path


def get_X_y(dataset, compressed_path, multilabel, replace=False):
    """Load a LIBSVM dataset as sparse X and observation y/Y.

    If X and y already exist as npz and npy, the archive is not decompressed
    and parsed again unless replace=True.

    Parameters
    ----------
    dataset : string
        Dataset name. Must be a key of NAMES and N_FEATURES.

    compressed_path : string
        Path of the bz2 archive of the dataset.

    multilabel : bool
        Whether the dataset is a multilabel one. If True, targets are stored
        and returned as a scipy.sparse.csr_matrix instead of a np.array.

    replace : bool, default=False
        Whether to rebuild the cached arrays even if they already exist.

    Returns
    -------
    X : scipy.sparse.csc_matrix
        Design matrix.

    y : np.array or scipy.sparse.csr_matrix
        Targets; sparse indicator matrix if multilabel is True.
    """
    ext = '.npz' if multilabel else '.npy'
    y_path = pjoin(CELER_PATH, "%s_target%s" % (NAMES[dataset], ext))
    X_path = pjoin(CELER_PATH, "%s_data.npz" % NAMES[dataset])

    is_cached = os.path.isfile(y_path) and os.path.isfile(X_path)
    if is_cached and not replace:
        X = sparse.load_npz(X_path)
        if multilabel:
            # Multilabel targets were written with sparse.save_npz; np.load
            # would return the raw npz archive instead of a sparse matrix.
            y = sparse.load_npz(y_path)
        else:
            y = np.load(y_path)
        return X, y

    tmp_path = pjoin(CELER_PATH, NAMES[dataset])
    decompressor = BZ2Decompressor()
    try:
        print("Decompressing...")
        with open(tmp_path, "wb") as f, open(compressed_path, "rb") as g:
            # Stream the archive by chunks of 100 kB to bound memory usage.
            for data in iter(lambda: g.read(100 * 1024), b''):
                f.write(decompressor.decompress(data))

        n_features_total = N_FEATURES[dataset]
        print("Loading svmlight file...")
        with open(tmp_path, 'rb') as f:
            # n_features is keyword-only in recent scikit-learn versions.
            X, y = load_svmlight_file(
                f, n_features=n_features_total, multilabel=multilabel)
    finally:
        # Do not leave the decompressed file behind, even on failure.
        if os.path.isfile(tmp_path):
            os.remove(tmp_path)

    X = sparse.csc_matrix(X)
    X.sort_indices()
    sparse.save_npz(X_path, X)

    if multilabel:
        # y holds one collection of labels per sample: build the CSR
        # indicator matrix with a one at each (sample, label) position.
        indices = np.array([lab for labels in y for lab in labels])
        indptr = np.cumsum([0] + [len(labels) for labels in y])
        data = np.ones_like(indices)
        Y = sparse.csr_matrix((data, indices, indptr))
        sparse.save_npz(y_path, Y)
        return X, Y

    np.save(y_path, y)
    return X, y


def load_libsvm(dataset, replace=False, normalize=True, min_nnz=3):
    """
    Download a dataset from LIBSVM website.

    Parameters
    ----------
    dataset : string
        Dataset name. Must be in celer.datasets.libsvm.NAMES.keys()

    replace : bool, default=False
        Whether to force download of dataset if already downloaded.

    normalize : bool, default=True
        If True, columns of X are set to unit norm. This may make little sense
        for a sparse matrix since centering is not performed.
        y is centered and scaled to unit standard deviation if the dataset is
        a regression one.

    min_nnz: int, default=3
        Columns of X with strictly less than min_nnz non-zero entries are
        discarded.

    Returns
    -------
    X : scipy.sparse.csc_matrix
        Design matrix.

    y : 1D np.array or scipy.sparse.csr_matrix
        Target vector, or sparse target matrix for multilabel datasets.

    Raises
    ------
    ValueError
        If dataset is not a key of NAMES.

    References
    ----------
    https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/
    """
    # Validate first, so that nothing is created on disk for a bad name.
    if dataset not in NAMES:
        raise ValueError("Unsupported dataset %s" % dataset)

    paths = [CELER_PATH, pjoin(CELER_PATH, 'regression'),
             pjoin(CELER_PATH, 'binary'),
             pjoin(CELER_PATH, 'multilabel')]
    for path in paths:
        # exist_ok avoids failing if the directory appears between a check
        # and its creation.
        os.makedirs(path, exist_ok=True)

    task = NAMES[dataset].split('/')[0]
    multilabel = task == 'multilabel'
    is_regression = task == 'regression'

    print("Dataset: %s" % dataset)
    compressed_path = pjoin(CELER_PATH, "%s.bz2" % NAMES[dataset])
    download_libsvm(dataset, compressed_path, replace=replace)

    X, y = get_X_y(dataset, compressed_path, multilabel, replace=replace)

    # preprocessing
    if min_nnz != 0:
        # X is in CSC format: successive differences of indptr give the
        # number of stored entries of each column.
        X = X[:, np.diff(X.indptr) >= min_nnz]

    if normalize:
        X = preprocessing.normalize(X, axis=0)
        if is_regression:
            y -= np.mean(y)
            y /= np.std(y)

    return X, y


if __name__ == "__main__":
    for dataset in NAMES:
        load_libsvm(dataset, replace=False)