Python datasets.get_dataset() Examples

The following are 10 code examples of datasets.get_dataset(), drawn from open-source projects. Each example lists its source file, the project it comes from, and that project's license.
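
Across these projects, get_dataset() is most often a small name-to-class registry: a dictionary mapping a string key to a dataset class or loader. The sketch below shows that common shape; it is illustrative only, and none of the names come from the projects featured here.

# Minimal sketch of the registry pattern behind most get_dataset()
# implementations below. All names here are illustrative assumptions.
_DATASET_REGISTRY = {}

def register_dataset(name):
    """Class decorator that records a dataset class under a string key."""
    def _wrap(cls):
        _DATASET_REGISTRY[name] = cls
        return cls
    return _wrap

def get_dataset(name):
    """Look up a registered dataset class by name."""
    try:
        return _DATASET_REGISTRY[name]
    except KeyError:
        raise ValueError('Unknown dataset: {}'.format(name))

@register_dataset('toy')
class ToyDataset:
    def __init__(self, **config):
        self.config = config

# Usage mirrors the examples below: look up the class, then instantiate.
dataset = get_dataset('toy')(data_name='toy')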
Example #1
Source File: datasets_test.py    From python-docs-samples with Apache License 2.0
def test_CRUD_dataset(capsys, crud_dataset_id):
    datasets.create_dataset(
        project_id,
        cloud_region,
        crud_dataset_id)

    datasets.get_dataset(
        project_id, cloud_region, crud_dataset_id)

    datasets.list_datasets(
        project_id, cloud_region)

    datasets.delete_dataset(
        project_id, cloud_region, crud_dataset_id)

    out, _ = capsys.readouterr()

    # Check that create/get/list/delete worked
    assert 'Created dataset' in out
    assert 'Time zone' in out
    assert 'Dataset' in out
    assert 'Deleted dataset' in out 
Example #2
Source File: datasets_test.py    From python-docs-samples with Apache License 2.0
def test_CRUD_dataset(capsys):
    datasets.create_dataset(
        service_account_json,
        project_id,
        cloud_region,
        dataset_id)

    datasets.get_dataset(
        service_account_json, project_id, cloud_region, dataset_id)

    datasets.list_datasets(
        service_account_json, project_id, cloud_region)

    # Test and also clean up
    datasets.delete_dataset(
        service_account_json, project_id, cloud_region, dataset_id)

    out, _ = capsys.readouterr()

    # Check that create/get/list/delete worked
    assert 'Created dataset' in out
    assert 'Time zone' in out
    assert 'Dataset' in out
    assert 'Deleted dataset' in out 
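
Examples #1 and #2 exercise a sample module from the Cloud Healthcare API docs. As a rough sketch of what the datasets.get_dataset() helper under test plausibly looks like, assuming the discovery-based Google API client (the actual sample may differ in detail):

from googleapiclient import discovery

def get_dataset(project_id, location, dataset_id):
    """Fetch a Cloud Healthcare API dataset and print its metadata."""
    service = discovery.build('healthcare', 'v1')
    dataset_name = 'projects/{}/locations/{}/datasets/{}'.format(
        project_id, location, dataset_id)
    dataset = service.projects().locations().datasets().get(
        name=dataset_name).execute()
    # The tests above assert on output like this (e.g. 'Time zone: UTC').
    print('Name: {}'.format(dataset.get('name')))
    print('Time zone: {}'.format(dataset.get('timeZone')))
    return dataset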
Example #3
Source File: evaluations.py    From pyslam with GNU General Public License v3.0
def extract_reg_feat(config):
    """Extract regional features."""
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'reg'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('reg_model')(config['pretrained']['reg_model'], **(config['reg_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            reg_f = h5py.File(dump_path, 'a')
            if 'reg_feat' not in reg_f or config['reg_feat']['overwrite']:
                reg_feat = model.run_test_data(data['image'])
                if 'reg_feat' in reg_f:
                    del reg_f['reg_feat']
                _ = reg_f.create_dataset('reg_feat', data=reg_feat)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            break
    model.close() 
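
Two details of the pyslam examples are worth noting: get_dataset(config['data_name']) returns a dataset class that is then instantiated with the config, and the read loop ends when the dataset raises its own end_set exception from next(). A minimal sketch of that contract, with assumed names that are not pyslam's actual code:

class EndOfSet(Exception):
    """Raised by the test-set generator when it is exhausted."""

class SequenceDataset:
    end_set = EndOfSet  # lets callers write `except dataset.end_set`

    def __init__(self, **config):
        self.config = config
        self.data_length = len(config.get('samples', []))

    def get_test_set(self):
        """Generator over samples; raises end_set instead of StopIteration."""
        for sample in self.config.get('samples', []):
            yield sample
        raise self.end_set

def get_dataset(data_name):
    """Return the dataset *class*; the caller instantiates it with **config."""
    return {'sequence': SequenceDataset}[data_name]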
Example #4
Source File: evaluations.py    From pyslam with GNU General Public License v3.0
def format_data(config):
    """Post-processing and generate custom files."""
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'post_format'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    idx = 0
    while True:
        try:
            data = next(test_set)
            dataset.format_data(data)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            break 
Example #5
Source File: worker.py    From super-simple-distributed-keras with MIT License
def evaluate_network(network, dataset):
    """Spawn a training session.

    Args:
        network (dict): The JSON definition of the network
        dataset (string): The name of the dataset to use
    """
    # Get the dataset.
    _, batch_size, _, x_train, x_test, y_train, y_test = get_dataset(dataset)

    model = model_from_json(network)
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])

    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=10000,  # essentially infinite, uses early stopping
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=[early_stopper])

    score = model.evaluate(x_test, y_test, verbose=0)

    metrics = {'loss': score[0], 'accuracy': score[1]}

    return metrics
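
Here get_dataset(name) returns a flat tuple; judging from the unpacking, the ignored elements are presumably the class count and input shape. A minimal sketch of a loader with that return shape, using Keras MNIST as an assumed stand-in for the project's real datasets:

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

def get_dataset(dataset):
    """Return (nb_classes, batch_size, input_shape, x_train, x_test, y_train, y_test)."""
    assert dataset == 'mnist', 'only mnist is sketched here'
    nb_classes, batch_size, input_shape = 10, 128, (784,)
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
    x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
    y_train = to_categorical(y_train, nb_classes)
    y_test = to_categorical(y_test, nb_classes)
    return nb_classes, batch_size, input_shape, x_train, x_test, y_train, y_test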
Example #6
Source File: training.py    From blitznet with MIT License
def main(argv=None):  # pylint: disable=unused-argument
    assert args.detect or args.segment, "Either detect or segment should be True"
    if args.trunk == 'resnet50':
        net = ResNet
        depth = 50
    elif args.trunk == 'vgg16':
        net = VGG
        depth = 16

    net = net(config=net_config, depth=depth, training=True, weight_decay=args.weight_decay)

    if args.dataset == 'voc07':
        dataset = get_dataset('voc07_trainval')
    elif args.dataset == 'voc12-trainval':
        dataset = get_dataset('voc12-train-segmentation', 'voc12-val')
    elif args.dataset == 'voc12-train':
        dataset = get_dataset('voc12-train-segmentation')
    elif args.dataset == 'voc12-val':
        dataset = get_dataset('voc12-val-segmentation')
    elif args.dataset == 'voc07+12':
        dataset = get_dataset('voc07_trainval', 'voc12_train', 'voc12_val')
    elif args.dataset == 'voc07+12-segfull':
        dataset = get_dataset('voc07-trainval-segmentation', 'voc12-train-segmentation', 'voc12-val')
    elif args.dataset == 'voc07+12-segmentation':
        dataset = get_dataset('voc07-trainval-segmentation', 'voc12-train-segmentation')
    elif args.dataset == 'coco':
        # supports the COCO trainval35k split by default
        dataset = get_dataset('coco-train2014-*', 'coco-valminusminival2014-*')
    elif args.dataset == 'coco-seg':
        # supports the COCO trainval35k split by default
        dataset = get_dataset('coco-seg-train2014-*', 'coco-seg-valminusminival2014-*')

    train(dataset, net, net_config)
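
Note that blitznet's get_dataset is variadic: passing several split names yields one merged dataset (e.g. the 'voc07+12' branch above). A toy sketch of that calling convention, with placeholder split contents:

_SPLITS = {
    'voc07_trainval': ['voc07_img_0001', 'voc07_img_0002'],  # placeholders
    'voc12_train': ['voc12_img_0001'],
    'voc12_val': ['voc12_img_0002'],
}

def get_dataset(*split_names):
    """Merge one or more named splits into a single sample list."""
    samples = []
    for name in split_names:
        samples.extend(_SPLITS[name])
    return samples

# Mirrors the 'voc07+12' branch above: three splits, one dataset.
dataset = get_dataset('voc07_trainval', 'voc12_train', 'voc12_val')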
Example #7
Source File: test.py    From amdim-public with MIT License
def main():
    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)

    _, test_loader, _ = build_dataset(dataset=dataset,
                                      batch_size=args.batch_size,
                                      input_dir=args.input_dir)

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer()

    model = checkpointer.restore_model_from_checkpoint(args.checkpoint_path)
    model = model.to(torch_device)
    model, _ = mixed_precision.initialize(model, None)

    test_stats = AverageMeterSet()
    test(model, test_loader, torch_device, test_stats)
    stat_str = test_stats.pretty_string(ignore=model.tasks)
    print(stat_str)
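
In amdim-public (this example and Example #10), get_dataset maps the command-line string onto a dataset identifier that helpers such as build_dataset and get_encoder_size then dispatch on. A rough sketch of that shape, assuming an enum-based lookup; the members and sizes here are illustrative, not the project's actual constants:

from enum import Enum

class Dataset(Enum):
    C10 = 1
    STL10 = 2
    IN128 = 3

def get_dataset(dset_name):
    """Map a command-line string (e.g. 'C10') to a Dataset identifier."""
    try:
        return Dataset[dset_name.upper()]
    except KeyError:
        raise KeyError('Unknown dataset: {}'.format(dset_name))

def get_encoder_size(dataset):
    """Pick an encoder input resolution per dataset (values illustrative)."""
    return {Dataset.C10: 32, Dataset.STL10: 64, Dataset.IN128: 128}[dataset]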
Example #8
Source File: evaluations.py    From pyslam with GNU General Public License v3.0
def extract_loc_feat(config):
    """Extract local features."""
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'loc'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('loc_model')(config['pretrained']['loc_model'], **(config['loc_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            loc_f = h5py.File(dump_path, 'a')
            if ('loc_info' not in loc_f and 'kpt' not in loc_f) or config['loc_feat']['overwrite']:
                # detect SIFT keypoints and crop image patches.
                loc_feat, kpt_mb, npy_kpts, cv_kpts, _ = model.run_test_data(data['image'])
                loc_info = np.concatenate((npy_kpts, loc_feat, kpt_mb), axis=-1)
                raw_kpts = [np.array((i.pt[0], i.pt[1], i.size, i.angle, i.response))
                            for i in cv_kpts]
                raw_kpts = np.stack(raw_kpts, axis=0)
                loc_info = np.concatenate((raw_kpts, loc_info), axis=-1)
                if 'loc_info' in loc_f:
                    del loc_f['loc_info']
                _ = loc_f.create_dataset('loc_info', data=loc_info)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            break
    model.close() 
Example #9
Source File: evaluations.py    From pyslam with GNU General Public License v3.0
def extract_aug_feat(config):
    """Extract augmented features."""
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'aug'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('aug_model')(config['pretrained']['loc_model'], **(config['aug_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            aug_f = h5py.File(dump_path, 'a')
            if 'aug_feat' not in aug_f or config['aug_feat']['overwrite']:
                aug_feat, _ = model.run_test_data(data['dump_data'])
                if 'aug_feat' in aug_f:
                    del aug_f['aug_feat']
                if aug_feat.dtype == np.uint8:
                    _ = aug_f.create_dataset('aug_feat', data=aug_feat, dtype='uint8')
                else:
                    _ = aug_f.create_dataset('aug_feat', data=aug_feat)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            break
    model.close() 
Example #10
Source File: train.py    From amdim-public with MIT License
def main():
    # create target output dir if it doesn't exist yet
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)

    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args.output_dir, args.run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args.batch_size,
                      input_dir=args.input_dir,
                      labeled_only=args.classifiers)

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args.output_dir)
    if args.cpt_load_path:
        model = checkpointer.restore_model_from_checkpoint(
            args.cpt_load_path,
            training_classifier=args.classifiers)
    else:
        # create new model with random parameters
        model = Model(ndf=args.ndf, n_classes=num_classes, n_rkhs=args.n_rkhs,
                      tclip=args.tclip, n_depth=args.n_depth, encoder_size=encoder_size,
                      use_bn=(args.use_bn == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)

    model = model.to(torch_device)

    # select which type of training to do
    task = train_classifiers if args.classifiers else train_self_supervised
    task(model, args.learning_rate, dataset, train_loader,
         test_loader, stat_tracker, checkpointer, args.output_dir, torch_device)