import numpy as np from sklearn import base from sklearn import utils class TopTermsClassifier(base.BaseEstimator, base.ClassifierMixin): def __init__(self, n_terms=10): self.n_terms = n_terms def fit(self, X, y=None, **fit_params): # scikit-learn checks X, y = utils.check_X_y(X, y, accept_sparse='csr', order='C') n_terms = min(self.n_terms, X.shape[1]) # Get a list of unique labels from y labels = np.unique(y) # Determine the n top terms per class self.top_terms_per_class_ = { c: set(np.argpartition(np.sum(X[y == c], axis=0), -n_terms)[-n_terms:]) for c in labels } # Return the classifier return self def _predict(self, x): # Find the terms in the document terms = set(np.where(x > 0)[0]) # Find the class that has the most top words in common with the document return max( self.top_terms_per_class_.keys(), key=lambda c: len(set.intersection(terms, self.top_terms_per_class_[c])) ) def predict(self, X): # scikit-learn checks X = utils.check_array(X, accept_sparse='csr', order='C') return np.array([self._predict(x) for x in X])