# -*- coding: utf-8 -*- # # Author: Taylor Smith <taylor.smith@alkaline-ml.com> # # Test random corner cases for both balancers from __future__ import division, absolute_import, division from sklearn.utils.multiclass import type_of_target from smrt.balance import base, smote_balance from nose.tools import assert_raises from sklearn.datasets import load_iris import numpy as np iris = load_iris() X, y = iris.data, iris.target def test_label_corner_cases(): # the current max classes is 100 (might change though). n_classes = base.MAX_N_CLASSES + 1 # create n_classes labels, append on itself so there are at least two of each # so sklearn will find it as a multi-class and not a continuous target labels = np.arange(n_classes) labels = np.concatenate([labels, labels]) # assert that it's multiclass and that we're getting the appropriate ValueError! y_type = type_of_target(labels) assert y_type == 'multiclass', y_type # create an X of random. Doesn't even matter. x = np.random.rand(labels.shape[0], 4) # try to balance, but it will fail because of the number of classes assert_raises(ValueError, smote_balance, x, labels) # now time for continuous... labels = np.linspace(0, 1000, x.shape[0]) # fails because improper y_type assert_raises(ValueError, smote_balance, x, labels) # perform a balancing operation with only one observation, and show that it will raise labels = np.zeros(x.shape[0]) labels[0] = 1 # this is the only one. y_type = type_of_target(labels) assert y_type == 'binary', y_type # fails because only one observation of one of the classes assert_raises(ValueError, smote_balance, x, labels) def test_smote_corner_cases(): # if n_neighbors is < 1... assert_raises(ValueError, smote_balance, X, y, n_neighbors=0) # show that a bad "strategy" is a ValueError assert_raises(ValueError, smote_balance, X, y, strategy='bad-input') # show that iris will not actually balance anything, since there is no majority class X_smote, y_smote = smote_balance(X, y) assert X_smote.shape[0] == y_smote.shape[0] == 150 # show that whether or not shuffle is used, it will not raise an index error: # likewise, show that models can be returned for shuffle in (True, False): for model in (True, False): output = smote_balance(X, y, shuffle=shuffle, return_estimators=model) if model: assert len(output) == 3