# -*- coding: utf-8 -*-
#
# Authors: Taylor Smith <taylor.smith@alkaline-ml.com>
#          Jason White <jason.m.white5@gmail.com>
#
# Commons and bases for the SMRT and SMOTE balancers

from __future__ import absolute_import

import numpy as np

from sklearn.utils import column_or_1d
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import check_array

from ..utils import validate_float, NPDTYPE

MAX_N_CLASSES = 100  # max unique classes in y
MIN_N_SAMPLES = 2  # min allowed ever.


def _validate_X_y_ratio_classes(X, y, ratio):
    """Validate balancer inputs and compute the per-class statistics.

    Parameters
    ----------
    X : array-like
        The feature matrix. Must be 2d and dense (sparse input is rejected).
        It is converted to dtype ``NPDTYPE`` and always *copied*, so callers
        may modify the returned array without touching the caller's data.

    y : array-like
        The class labels, one per row of ``X``. Must be a 'binary' or
        'multiclass' target as reported by
        ``sklearn.utils.multiclass.type_of_target``.

    ratio : float
        The balance ratio, checked by the project helper ``validate_float``
        before any array work is done. It scales the majority-class count
        to produce ``target_count``.

    Returns
    -------
    X : np.ndarray
        The validated, copied feature matrix.

    y : np.ndarray
        The validated labels, flattened to 1d.

    n_classes : int
        Number of distinct labels present in ``y``.

    present_classes : np.ndarray
        The distinct labels, as returned (sorted) by ``np.unique``.

    counts : np.ndarray
        Number of occurrences of each label in ``present_classes``.

    majority_label : object
        The label with the highest count (ties go to the first label in
        ``present_classes``).

    target_count : int
        ``max(int(ratio * majority_count), 1)``.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` have different numbers of samples, if ``y`` is not
        a binary/multiclass target, if ``y`` holds more than ``MAX_N_CLASSES``
        distinct labels, or if any label occurs fewer than ``MIN_N_SAMPLES``
        times. Errors raised by ``validate_float``, ``check_array`` and
        ``column_or_1d`` propagate unchanged.
    """
    # validate the cheap stuff before copying arrays around...
    validate_float(ratio, 'balance_ratio')

    # validate arrays
    X = check_array(X, accept_sparse=False, dtype=NPDTYPE, ensure_2d=True, copy=True)
    y = check_array(y, accept_sparse=False, ensure_2d=False, dtype=None)
    y = column_or_1d(y, warn=False)

    # Every row of X needs exactly one label. Without this check a mismatch
    # would only surface (if at all) somewhere in the downstream sampling code.
    n_samples, n_labels = X.shape[0], y.shape[0]
    if n_samples != n_labels:
        raise ValueError('X and y must contain the same number of samples, '
                         'but got %i and %i' % (n_samples, n_labels))

    # Before counting classes, make sure y actually holds class labels and
    # not continuous values, multilabel indicators, etc.
    y_type = type_of_target(y)
    supported_types = {'multiclass', 'binary'}
    if y_type not in supported_types:
        # sorted(): iteration order of a set of str varies with hash
        # randomization, which would make this message non-deterministic.
        raise ValueError('balancers only support %r, but got %r'
                         % ("(" + ', '.join(sorted(supported_types)) + ")", y_type))

    present_classes, counts = np.unique(y, return_counts=True)
    n_classes = len(present_classes)

    # ensure <= MAX_N_CLASSES
    if n_classes > MAX_N_CLASSES:
        raise ValueError('balancers currently only support a maximum of %i '
                         'unique class labels, but %i were identified.'
                         % (MAX_N_CLASSES, n_classes))

    # np.argmax returns the first maximal index, so ties resolve to the
    # smallest label (np.unique returns labels sorted).
    majority_count_idx = np.argmax(counts, axis=0)
    majority_label = present_classes[majority_count_idx]
    majority_count = counts[majority_count_idx]

    # int() truncates toward zero; the floor of 1 keeps the target positive
    # even for very small ratios.
    target_count = max(int(ratio * majority_count), 1)

    # NOTE: an earlier design rejected inputs where a minority class would
    # need more synthetic samples than it originally held (forcing sampling
    # from synthetic examples). That is a legitimate use case (e.g. a rare
    # positive class that needs heavy bolstering), so it is now allowed; the
    # only per-class requirement is a minimum count.
    if np.any(counts < MIN_N_SAMPLES):
        raise ValueError('All label counts must be >= %i' % MIN_N_SAMPLES)

    return X, y, n_classes, present_classes, counts, majority_label, target_count