Python sklearn.model_selection.StratifiedShuffleSplit() Examples

The following are 30 code examples showing how to use sklearn.model_selection.StratifiedShuffleSplit(). They are extracted from open source projects; the originating project, author, and source file are listed above each example.

You may also want to check out all available functions and classes of the module sklearn.model_selection.
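
Before the project-specific examples, the following minimal sketch illustrates the basic pattern most of them follow: construct the splitter with n_splits, test_size, and random_state, then iterate over the (train, test) index arrays yielded by split(X, y). The data and variable names below are synthetic and purely illustrative; they are not taken from any of the projects listed.

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Synthetic, illustrative data: 12 samples with an 8/4 class imbalance.
X = np.arange(24).reshape(12, 2)
y = np.array([0] * 8 + [1] * 4)

# One stratified split holding out 25% of the samples; the 2:1 class ratio
# is preserved in both index sets (expected counts: train [6 3], test [2 1]).
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
for train_idx, test_idx in sss.split(X, y):
    print("train class counts:", np.bincount(y[train_idx]))
    print("test class counts:", np.bincount(y[test_idx]))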

Example 1
Project: cwcf   Author: jaromiru   File: hpc_svm.py    License: MIT License
def get_full_rbf_svm_clf(train_x, train_y, c_range=None, gamma_range=None):
    param_grid = dict(gamma=gamma_range, C=c_range)
    cv = StratifiedShuffleSplit(n_splits=2, test_size=0.2, random_state=42)
    grid = GridSearchCV(SVC(cache_size=1024), param_grid=param_grid, cv=cv, n_jobs=14, verbose=10)
    grid.fit(train_x, train_y)

    print("The best parameters are %s with a score of %0.2f" % (grid.best_params_, grid.best_score_))

    scores = grid.cv_results_['mean_test_score'].reshape(len(c_range), len(gamma_range))
    print("Scores:")
    print(scores)

    print("c_range:", c_range)
    print("gamma_range:", gamma_range)

    c_best = grid.best_params_['C']
    gamma_best = grid.best_params_['gamma']

    clf = SVC(C=c_best, gamma=gamma_best, verbose=True)
    return clf

#---------------- 
Example 2
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_split.py    License: MIT License
def test_2d_y():
    # smoke test for 2d y and multi-label
    n_samples = 30
    rng = np.random.RandomState(1)
    X = rng.randint(0, 3, size=(n_samples, 2))
    y = rng.randint(0, 3, size=(n_samples,))
    y_2d = y.reshape(-1, 1)
    y_multilabel = rng.randint(0, 2, size=(n_samples, 3))
    groups = rng.randint(0, 3, size=(n_samples,))
    splitters = [LeaveOneOut(), LeavePOut(p=2), KFold(), StratifiedKFold(),
                 RepeatedKFold(), RepeatedStratifiedKFold(),
                 ShuffleSplit(), StratifiedShuffleSplit(test_size=.5),
                 GroupShuffleSplit(), LeaveOneGroupOut(),
                 LeavePGroupsOut(n_groups=2), GroupKFold(), TimeSeriesSplit(),
                 PredefinedSplit(test_fold=groups)]
    for splitter in splitters:
        list(splitter.split(X, y, groups))
        list(splitter.split(X, y_2d, groups))
        try:
            list(splitter.split(X, y_multilabel, groups))
        except ValueError as e:
            allowed_target_types = ('binary', 'multiclass')
            msg = "Supported target types are: {}. Got 'multilabel".format(
                allowed_target_types)
            assert msg in str(e) 
Example 3
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_split.py    License: MIT License
def test_stratified_shuffle_split_init():
    X = np.arange(7)
    y = np.asarray([0, 1, 1, 1, 2, 2, 2])
    # Check that error is raised if there is a class with only one sample
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(3, 0.2).split(X, y))

    # Check that error is raised if the test set size is smaller than n_classes
    assert_raises(ValueError, next, StratifiedShuffleSplit(3, 2).split(X, y))
    # Check that error is raised if the train set size is smaller than
    # n_classes
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(3, 3, 2).split(X, y))

    X = np.arange(9)
    y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2])

    # Train size or test size too small
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(train_size=2).split(X, y))
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(test_size=2).split(X, y)) 
Example 4
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_split.py    License: MIT License
def test_stratified_shuffle_split_multilabel():
    # fix for issue 9037
    for y in [np.array([[0, 1], [1, 0], [1, 0], [0, 1]]),
              np.array([[0, 1], [1, 1], [1, 1], [0, 1]])]:
        X = np.ones_like(y)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
        train, test = next(sss.split(X=X, y=y))
        y_train = y[train]
        y_test = y[test]

        # no overlap
        assert_array_equal(np.intersect1d(train, test), [])

        # complete partition
        assert_array_equal(np.union1d(train, test), np.arange(len(y)))

        # correct stratification of entire rows
        # (by design, here y[:, 0] uniquely determines the entire row of y)
        expected_ratio = np.mean(y[:, 0])
        assert_equal(expected_ratio, np.mean(y_train[:, 0]))
        assert_equal(expected_ratio, np.mean(y_test[:, 0])) 
Example 5
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_split.py    License: MIT License
def test_stratified_shuffle_split_multilabel_many_labels():
    # fix in PR #9922: for multilabel data with > 1000 labels, str(row)
    # truncates with an ellipsis for elements in positions 4 through
    # len(row) - 4, so labels were not being correctly split using the powerset
    # method for transforming a multilabel problem to a multiclass one; this
    # test checks that this problem is fixed.
    row_with_many_zeros = [1, 0, 1] + [0] * 1000 + [1, 0, 1]
    row_with_many_ones = [1, 0, 1] + [1] * 1000 + [1, 0, 1]
    y = np.array([row_with_many_zeros] * 10 + [row_with_many_ones] * 100)
    X = np.ones_like(y)

    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
    train, test = next(sss.split(X=X, y=y))
    y_train = y[train]
    y_test = y[test]

    # correct stratification of entire rows
    # (by design, here y[:, 4] uniquely determines the entire row of y)
    expected_ratio = np.mean(y[:, 4])
    assert_equal(expected_ratio, np.mean(y_train[:, 4]))
    assert_equal(expected_ratio, np.mean(y_test[:, 4])) 
Example 6
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_search.py    License: MIT License
def test_grid_search_groups():
    # Check if ValueError (when groups is None) propagates to GridSearchCV
    # And also check if groups is correctly passed to the cv object
    rng = np.random.RandomState(0)

    X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
    groups = rng.randint(0, 3, 15)

    clf = LinearSVC(random_state=0)
    grid = {'C': [1]}

    group_cvs = [LeaveOneGroupOut(), LeavePGroupsOut(2), GroupKFold(),
                 GroupShuffleSplit()]
    for cv in group_cvs:
        gs = GridSearchCV(clf, grid, cv=cv)
        assert_raise_message(ValueError,
                             "The 'groups' parameter should not be None.",
                             gs.fit, X, y)
        gs.fit(X, y, groups=groups)

    non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit()]
    for cv in non_group_cvs:
        gs = GridSearchCV(clf, grid, cv=cv)
        # Should not raise an error
        gs.fit(X, y) 
Example 7
Project: Fall-Detection-with-CNNs-and-Optical-Flow   Author: AdrianNunez   File: temporalnet_combined.py    License: MIT License
def divide_train_val(zeroes, ones, val_size):
    """ sss = StratifiedShuffleSplit(n_splits=1,
                    test_size=val_size/2,
                    random_state=7)
    indices_0 = sss.split(np.zeros(len(zeroes)), zeroes)
    indices_1 = sss.split(np.zeros(len(ones)), ones)
    train_indices_0, val_indices_0 = indices_0.next()
    train_indices_1, val_indices_1 = indices_1.next() """

    rand0 = np.random.permutation(len(zeroes))
    train_indices_0 = zeroes[rand0[val_size//2:]]
    val_indices_0 = zeroes[rand0[:val_size//2]]
    rand1 = np.random.permutation(len(ones))
    train_indices_1 = ones[rand1[val_size//2:]]
    val_indices_1 = ones[rand1[:val_size//2]]

    return (train_indices_0, train_indices_1,
            val_indices_0, val_indices_1) 
Example 8
Project: self-ensemble-visual-domain-adapt-photo   Author: Britefury   File: image_dataset.py    License: MIT License
def subset_indices(d_source, d_target, subsetsize, subsetseed):
    if subsetsize > 0:
        if subsetseed != 0:
            subset_rng = np.random.RandomState(subsetseed)
        else:
            subset_rng = np.random
        strat = StratifiedShuffleSplit(n_splits=1, test_size=subsetsize, random_state=subset_rng)
        shuf = ShuffleSplit(n_splits=1, test_size=subsetsize, random_state=subset_rng)
        _, source_indices = next(strat.split(d_source.y, d_source.y))
        n_src = source_indices.shape[0]
        if d_target.has_ground_truth:
            _, target_indices = next(strat.split(d_target.y, d_target.y))
        else:
            _, target_indices = next(shuf.split(np.arange(len(d_target.images))))
        n_tgt = target_indices.shape[0]
    else:
        source_indices = None
        target_indices = None
        n_src = len(d_source.images)
        n_tgt = len(d_target.images)

    return source_indices, target_indices, n_src, n_tgt 
Example 9
Project: nonconformist   Author: donlnz   File: acp.py    License: MIT License
def gen_samples(self, y, n_samples, problem_type):
		if problem_type == 'classification':
			splits = StratifiedShuffleSplit(
					n_splits=n_samples,
					test_size=self.cal_portion
				)

			split_ = splits.split(np.zeros((y.size, 1)), y)
		
		else:
			splits = ShuffleSplit(
				n_splits=n_samples,
				test_size=self.cal_portion
			)

			split_ = splits.split(np.zeros((y.size, 1)))

		for train, cal in split_:
			yield train, cal


# -----------------------------------------------------------------------------
# Conformal ensemble
# ----------------------------------------------------------------------------- 
Example 10
Project: twitter-stock-recommendation   Author: alvarobartt   File: test_split.py    License: MIT License
def test_2d_y():
    # smoke test for 2d y and multi-label
    n_samples = 30
    rng = np.random.RandomState(1)
    X = rng.randint(0, 3, size=(n_samples, 2))
    y = rng.randint(0, 3, size=(n_samples,))
    y_2d = y.reshape(-1, 1)
    y_multilabel = rng.randint(0, 2, size=(n_samples, 3))
    groups = rng.randint(0, 3, size=(n_samples,))
    splitters = [LeaveOneOut(), LeavePOut(p=2), KFold(), StratifiedKFold(),
                 RepeatedKFold(), RepeatedStratifiedKFold(),
                 ShuffleSplit(), StratifiedShuffleSplit(test_size=.5),
                 GroupShuffleSplit(), LeaveOneGroupOut(),
                 LeavePGroupsOut(n_groups=2), GroupKFold(), TimeSeriesSplit(),
                 PredefinedSplit(test_fold=groups)]
    for splitter in splitters:
        list(splitter.split(X, y, groups))
        list(splitter.split(X, y_2d, groups))
        try:
            list(splitter.split(X, y_multilabel, groups))
        except ValueError as e:
            allowed_target_types = ('binary', 'multiclass')
            msg = "Supported target types are: {}. Got 'multilabel".format(
                allowed_target_types)
            assert msg in str(e) 
Example 11
Project: twitter-stock-recommendation   Author: alvarobartt   File: test_split.py    License: MIT License
def test_stratified_shuffle_split_multilabel():
    # fix for issue 9037
    for y in [np.array([[0, 1], [1, 0], [1, 0], [0, 1]]),
              np.array([[0, 1], [1, 1], [1, 1], [0, 1]])]:
        X = np.ones_like(y)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
        train, test = next(sss.split(X=X, y=y))
        y_train = y[train]
        y_test = y[test]

        # no overlap
        assert_array_equal(np.intersect1d(train, test), [])

        # complete partition
        assert_array_equal(np.union1d(train, test), np.arange(len(y)))

        # correct stratification of entire rows
        # (by design, here y[:, 0] uniquely determines the entire row of y)
        expected_ratio = np.mean(y[:, 0])
        assert_equal(expected_ratio, np.mean(y_train[:, 0]))
        assert_equal(expected_ratio, np.mean(y_test[:, 0])) 
Example 12
Project: twitter-stock-recommendation   Author: alvarobartt   File: test_split.py    License: MIT License
def test_stratified_shuffle_split_multilabel_many_labels():
    # fix in PR #9922: for multilabel data with > 1000 labels, str(row)
    # truncates with an ellipsis for elements in positions 4 through
    # len(row) - 4, so labels were not being correctly split using the powerset
    # method for transforming a multilabel problem to a multiclass one; this
    # test checks that this problem is fixed.
    row_with_many_zeros = [1, 0, 1] + [0] * 1000 + [1, 0, 1]
    row_with_many_ones = [1, 0, 1] + [1] * 1000 + [1, 0, 1]
    y = np.array([row_with_many_zeros] * 10 + [row_with_many_ones] * 100)
    X = np.ones_like(y)

    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
    train, test = next(sss.split(X=X, y=y))
    y_train = y[train]
    y_test = y[test]

    # correct stratification of entire rows
    # (by design, here y[:, 4] uniquely determines the entire row of y)
    expected_ratio = np.mean(y[:, 4])
    assert_equal(expected_ratio, np.mean(y_train[:, 4]))
    assert_equal(expected_ratio, np.mean(y_test[:, 4])) 
Example 13
Project: HungaBunga   Author: ypeleg   File: core.py    License: MIT License
def cv_clf(x, y, test_size=0.2, n_splits=5, random_state=None, doesUpsample=True):
    sss_obj = sss(n_splits, test_size, random_state=random_state).split(x, y)
    if not doesUpsample:
        yield sss_obj
    for train_inds, valid_inds in sss_obj:
        yield (upsample_indices_clf(train_inds, y[train_inds]), valid_inds)
Example 14
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_split.py    License: MIT License
def test_stratified_shuffle_split_respects_test_size():
    y = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2])
    test_size = 5
    train_size = 10
    sss = StratifiedShuffleSplit(6, test_size=test_size, train_size=train_size,
                                 random_state=0).split(np.ones(len(y)), y)
    for train, test in sss:
        assert_equal(len(train), train_size)
        assert_equal(len(test), test_size) 
Example 15
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_split.py    License: MIT License
def test_stratified_shuffle_split_iter():
    ys = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
          np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
          np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
          np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
          np.array([-1] * 800 + [1] * 50),
          np.concatenate([[i] * (100 + i) for i in range(11)]),
          [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
          ['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'],
          ]

    for y in ys:
        sss = StratifiedShuffleSplit(6, test_size=0.33,
                                     random_state=0).split(np.ones(len(y)), y)
        y = np.asanyarray(y)  # To make it indexable for y[train]
        # this is how test-size is computed internally
        # in _validate_shuffle_split
        test_size = np.ceil(0.33 * len(y))
        train_size = len(y) - test_size
        for train, test in sss:
            assert_array_equal(np.unique(y[train]), np.unique(y[test]))
            # Checks if folds keep classes proportions
            p_train = (np.bincount(np.unique(y[train],
                                   return_inverse=True)[1]) /
                       float(len(y[train])))
            p_test = (np.bincount(np.unique(y[test],
                                  return_inverse=True)[1]) /
                      float(len(y[test])))
            assert_array_almost_equal(p_train, p_test, 1)
            assert_equal(len(train) + len(test), y.size)
            assert_equal(len(train), train_size)
            assert_equal(len(test), test_size)
            assert_array_equal(np.lib.arraysetops.intersect1d(train, test), []) 
Example 16
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_split.py    License: MIT License
def test_stratifiedshufflesplit_list_input():
    # Check that when y is a list / list of string labels, it works.
    sss = StratifiedShuffleSplit(test_size=2, random_state=42)
    X = np.ones(7)
    y1 = ['1'] * 4 + ['0'] * 3
    y2 = np.hstack((np.ones(4), np.zeros(3)))
    y3 = y2.tolist()

    np.testing.assert_equal(list(sss.split(X, y1)),
                            list(sss.split(X, y2)))
    np.testing.assert_equal(list(sss.split(X, y3)),
                            list(sss.split(X, y2))) 
Example 17
Project: Mastering-Elasticsearch-7.0   Author: PacktPublishing   File: test_split.py    License: MIT License
def test_nested_cv():
    # Test if nested cross validation works with different combinations of cv
    rng = np.random.RandomState(0)

    X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
    groups = rng.randint(0, 5, 15)

    cvs = [LeaveOneGroupOut(), LeaveOneOut(), GroupKFold(), StratifiedKFold(),
           StratifiedShuffleSplit(n_splits=3, random_state=0)]

    for inner_cv, outer_cv in combinations_with_replacement(cvs, 2):
        gs = GridSearchCV(Ridge(), param_grid={'alpha': [1, .1]},
                          cv=inner_cv, error_score='raise', iid=False)
        cross_val_score(gs, X=X, y=y, groups=groups, cv=outer_cv,
                        fit_params={'groups': groups}) 
Example 18
Project: skorch   Author: skorch-dev   File: dataset.py    License: BSD 3-Clause "New" or "Revised" License
def _is_stratified(self, cv):
        return isinstance(cv, (StratifiedKFold, StratifiedShuffleSplit)) 
Example 19
Project: skorch   Author: skorch-dev   File: dataset.py    License: BSD 3-Clause "New" or "Revised" License
def _check_cv_float(self):
        cv_cls = StratifiedShuffleSplit if self.stratified else ShuffleSplit
        return cv_cls(test_size=self.cv, random_state=self.random_state) 
Example 20
Project: autoreject   Author: autoreject   File: autoreject.py    License: BSD 3-Clause "New" or "Revised" License
def _compute_thresholds(epochs, method='bayesian_optimization',
                        random_state=None, picks=None, augment=True,
                        dots=None, verbose='progressbar', n_jobs=1):
    if method not in ['bayesian_optimization', 'random_search']:
        raise ValueError('`method` param not recognized')
    picks = _handle_picks(info=epochs.info, picks=picks)
    _check_data(epochs, picks, verbose=verbose,
                ch_constraint='data_channels')
    picks_by_type = _get_picks_by_type(picks=picks, info=epochs.info)
    picks_by_type = None if len(picks_by_type) == 1 else picks_by_type  # XXX
    if picks_by_type is not None:
        threshes = dict()
        for ch_type, this_picks in picks_by_type:
            threshes.update(_compute_thresholds(
                epochs=epochs, method=method, random_state=random_state,
                picks=this_picks, augment=augment, dots=dots,
                verbose=verbose, n_jobs=n_jobs))
    else:
        n_epochs = len(epochs)
        data, y = epochs.get_data(), np.ones((n_epochs, ))
        if augment:
            epochs_interp = _clean_by_interp(epochs, picks=picks,
                                             dots=dots, verbose=verbose)
            # non-data channels will be duplicate
            data = np.concatenate((epochs.get_data(),
                                   epochs_interp.get_data()), axis=0)
            y = np.r_[np.zeros((n_epochs, )), np.ones((n_epochs, ))]
        cv = StratifiedShuffleSplit(n_splits=10, test_size=0.2,
                                    random_state=random_state)

        ch_names = epochs.ch_names

        my_thresh = delayed(_compute_thresh)
        parallel = Parallel(n_jobs=n_jobs, verbose=0)
        desc = 'Computing thresholds ...'
        threshes = parallel(
            my_thresh(data[:, pick], cv=cv, method=method, y=y,
                      random_state=random_state)
            for pick in _pbar(picks, desc=desc, verbose=verbose))
        threshes = {ch_names[p]: thresh for p, thresh in zip(picks, threshes)}
    return threshes 
Example 21
Project: ramp-workflow   Author: paris-saclay-cds   File: problem.py    License: BSD 3-Clause "New" or "Revised" License
def get_cv(X, y):
    cv = StratifiedShuffleSplit(n_splits=8, test_size=0.2, random_state=57)
    return cv.split(X, y) 
Example 22
Project: ramp-workflow   Author: paris-saclay-cds   File: problem.py    License: BSD 3-Clause "New" or "Revised" License
def get_cv(X, y):
    cv = StratifiedShuffleSplit(n_splits=2, test_size=0.2, random_state=57)
    return cv.split(X, y) 
Example 23
Project: ramp-workflow   Author: paris-saclay-cds   File: problem.py    License: BSD 3-Clause "New" or "Revised" License
def get_cv(folder_X, y):
    _, X = folder_X
    cv = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=57)
    return cv.split(X, y) 
Example 24
Project: ramp-workflow   Author: paris-saclay-cds   File: problem.py    License: BSD 3-Clause "New" or "Revised" License
def get_cv(X, y):
    cv = StratifiedShuffleSplit(n_splits=2, test_size=0.2, random_state=57)
    return cv.split(X, y) 
Example 25
Project: ramp-workflow   Author: paris-saclay-cds   File: problem.py    License: BSD 3-Clause "New" or "Revised" License
def get_cv(X, y):
    cv = StratifiedShuffleSplit(n_splits=8, test_size=0.2, random_state=57)
    return cv.split(X, y) 
Example 26
Project: ramp-workflow   Author: paris-saclay-cds   File: problem.py    License: BSD 3-Clause "New" or "Revised" License
def get_cv(X, y):
    cv = StratifiedShuffleSplit(n_splits=8, test_size=0.2, random_state=57)
    return cv.split(X, y) 
Example 27
Project: ramp-workflow   Author: paris-saclay-cds   File: problem.py    License: BSD 3-Clause "New" or "Revised" License
def get_cv(X, y):
    cv = StratifiedShuffleSplit(n_splits=8, test_size=0.2, random_state=57)
    return cv.split(X, y) 
Example 28
Project: ramp-workflow   Author: paris-saclay-cds   File: problem.py    License: BSD 3-Clause "New" or "Revised" License
def get_cv(folder_X, y):
    _, X = folder_X
    cv = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=57)
    return cv.split(X, y) 
Example 29
Project: NeMo   Author: NVIDIA   File: scp_to_manifest.py    License: Apache License 2.0
def main(scp, id, out, split=False):
    if os.path.exists(out):
        os.remove(out)
    scp_file = open(scp, 'r').readlines()

    lines = []
    speakers = []
    with open(out, 'w') as outfile:
        for line in tqdm(scp_file):
            line = line.strip()
            y, sr = l.load(line, sr=None)
            dur = l.get_duration(y=y, sr=sr)
            speaker = line.split('/')[id]
            speaker = list(speaker)
            speaker = ''.join(speaker)
            speakers.append(speaker)
            meta = {"audio_filepath": line, "duration": float(dur), "label": speaker}
            lines.append(meta)
            json.dump(meta, outfile)
            outfile.write("\n")

    path = os.path.dirname(out)
    if split:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
        for train_idx, test_idx in sss.split(speakers, speakers):
            logging.info(len(train_idx))

        out = os.path.join(path, 'train.json')
        write_file(out, lines, train_idx)
        out = os.path.join(path, 'dev.json')
        write_file(out, lines, test_idx) 
Example 30
Project: gumpy   Author: gumpy-bci   File: split.py    License: MIT License
def stratified_shuffle_Split(features, labels, n_splits, test_size, random_state):
    """Stratified ShuffleSplit cross-validator."""
    cv = StratifiedShuffleSplit(n_splits, test_size, random_state=random_state)
    for train_index, test_index in cv.split(features, labels):
        X_train = features[train_index]
        X_test = features[test_index]
        Y_train = labels[train_index]
        Y_test = labels[test_index]
    return X_train, X_test, Y_train, Y_test


#Random permutation cross-validator