Python train.train() Examples
The following are 30 code examples of train.train().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module train, or try the search function.
Example #1
Source File: train_main.py From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def main(_):
    """Train a progressive GAN, one resolution stage at a time.

    Args:
        _: Unused positional argument (presumably argv from the app runner).
    """
    if not tf.gfile.Exists(FLAGS.train_root_dir):
        tf.gfile.MakeDirs(FLAGS.train_root_dir)

    config = _make_config_from_flags()
    # dict.items() works on Python 2 and 3; dict.iteritems() is Python 2 only
    # and raises AttributeError on Python 3.
    logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.items()]))

    for stage_id in train.get_stage_ids(**config):
        batch_size = train.get_batch_size(stage_id, **config)
        # Each stage builds its own graph from scratch.
        tf.reset_default_graph()
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # The input pipeline is pinned to the CPU.
            with tf.device('/cpu:0'), tf.name_scope('inputs'):
                real_images = _provide_real_images(batch_size, **config)
            model = train.build_model(stage_id, batch_size, real_images, **config)
            train.add_model_summaries(model, **config)
            train.train(model, **config)
Example #2
Source File: train_main.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def _provide_real_images(batch_size, **kwargs):
    """Provides real images.

    ``dataset_name`` (a named dataset) takes precedence over
    ``dataset_file_pattern`` (image files on disk). Patches are sized to the
    final resolution of the schedule built from ``kwargs``.

    Args:
        batch_size: Number of images per batch.
        **kwargs: Training config. Must contain ``colors`` and at least one of
            ``dataset_name`` / ``dataset_file_pattern``.

    Returns:
        Whatever the selected ``data_provider`` function returns.

    Raises:
        KeyError: If ``colors`` is missing from ``kwargs``.
        ValueError: If neither ``dataset_name`` nor ``dataset_file_pattern``
            is set.
    """
    dataset_name = kwargs.get('dataset_name')
    dataset_file_pattern = kwargs.get('dataset_file_pattern')
    colors = kwargs['colors']
    final_height, final_width = train.make_resolution_schedule(
        **kwargs).final_resolutions

    if dataset_name is not None:
        return data_provider.provide_data(
            dataset_name=dataset_name,
            split_name='train',
            batch_size=batch_size,
            patch_height=final_height,
            patch_width=final_width,
            colors=colors)
    if dataset_file_pattern is not None:
        return data_provider.provide_data_from_image_files(
            file_pattern=dataset_file_pattern,
            batch_size=batch_size,
            patch_height=final_height,
            patch_width=final_width,
            colors=colors)
    # Previously this fell through and implicitly returned None, handing the
    # caller a missing input instead of reporting the misconfiguration.
    raise ValueError(
        'Either dataset_name or dataset_file_pattern must be provided.')
Example #3
Source File: train_main.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def main(_):
    """Train a progressive GAN, one resolution stage at a time.

    Args:
        _: Unused positional argument (presumably argv from the app runner).
    """
    if not tf.gfile.Exists(FLAGS.train_root_dir):
        tf.gfile.MakeDirs(FLAGS.train_root_dir)

    config = _make_config_from_flags()
    # dict.items() works on Python 2 and 3; dict.iteritems() is Python 2 only
    # and raises AttributeError on Python 3.
    logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.items()]))

    for stage_id in train.get_stage_ids(**config):
        batch_size = train.get_batch_size(stage_id, **config)
        # Each stage builds its own graph from scratch.
        tf.reset_default_graph()
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # The input pipeline is pinned to the CPU.
            with tf.device('/cpu:0'), tf.name_scope('inputs'):
                real_images = _provide_real_images(batch_size, **config)
            model = train.build_model(stage_id, batch_size, real_images, **config)
            train.add_model_summaries(model, **config)
            train.train(model, **config)
Example #4
Source File: optimize_parameters.py From facial-expression-recognition-svm with GNU General Public License v3.0 | 6 votes |
def function_to_minimize(hyperparams, gamma='auto', decision_function='ovr'):
    """Hyperopt objective: train an SVM with the sampled settings.

    Args:
        hyperparams: Mapping of sampled values. Its 'decision_function' and
            'gamma' entries take precedence over the keyword arguments.
        gamma: Fallback used when ``hyperparams`` has no 'gamma' key.
        decision_function: Fallback used when ``hyperparams`` has no
            'decision_function' key.

    Returns:
        dict with 'loss' (negated accuracy), 'time' (training seconds, rounded)
        and 'status' (STATUS_OK).

    On any exception during training, the accumulated ``train_history`` is
    saved to train_history.npy and the process exits.
    """
    global current_eval
    # Sampled values win; the keyword defaults are only a fallback. Previously
    # the defaults were overwritten unconditionally, so they were dead and a
    # search space without these keys raised KeyError.
    decision_function = hyperparams.get('decision_function', decision_function)
    gamma = hyperparams.get('gamma', gamma)

    print("#################################")
    print(" Evaluation {} of {}".format(current_eval, max_evals))
    print("#################################")
    start_time = time.time()
    try:
        accuracy = train(epochs=HYPERPARAMS.epochs_during_hyperopt,
                         decision_function=decision_function, gamma=gamma)
        training_time = int(round(time.time() - start_time))
        current_eval += 1
        train_history.append({'accuracy': accuracy,
                              'decision_function': decision_function,
                              'gamma': gamma, 'time': training_time})
    except Exception as e:
        # Persist what has been collected so far before stopping the search.
        print("#################################")
        print("Exception during training: {}".format(str(e)))
        print("Saving train history in train_history.npy")
        np.save("train_history.npy", train_history)
        exit()
    return {'loss': -accuracy, 'time': training_time, 'status': STATUS_OK}

# launch the hyperparameters search
Example #5
Source File: main.py From toxic_comments with MIT License | 6 votes |
def get_kwargs(kwargs):
    """Parse command-line options and merge them into ``kwargs`` in place.

    Every parsed option is stored in ``kwargs`` under its ``dest`` name,
    overwriting any existing key with the same name.

    Args:
        kwargs: Dict to update.

    Returns:
        The same ``kwargs`` dict, for convenience (the function used to return
        None; existing callers that ignore the result are unaffected).
    """
    parser = argparse.ArgumentParser(description='-f TRAIN_FILE -t TEST_FILE -o OUTPUT_FILE -e EMBEDS_FILE [-l LOGGER_FILE] [--swear-words SWEAR_FILE] [--wrong-words WRONG_WORDS_FILE] [--format-embeds FALSE]')
    parser.add_argument('-f', '--train', dest='train', action='store', help='/path/to/train_file', type=str)
    parser.add_argument('-t', '--test', dest='test', action='store', help='/path/to/test_file', type=str)
    parser.add_argument('-o', '--output', dest='output', action='store', help='/path/to/output_file', type=str)
    parser.add_argument('-we', '--word_embeds', dest='word_embeds', action='store', help='/path/to/embeds_file', type=str)
    parser.add_argument('-ce', '--char_embeds', dest='char_embeds', action='store', help='/path/to/embeds_file', type=str)
    parser.add_argument('-c', '--config', dest='config', action='store', help='/path/to/config.json', type=str)
    parser.add_argument('-l', '--logger', dest='logger', action='store', help='/path/to/log_file', type=str, default=None)
    parser.add_argument('--mode', dest='mode', action='store', help='preprocess / train / validate / all', type=str, default='all')
    parser.add_argument('--max-words', dest='max_words', action='store', type=int, default=300000)
    parser.add_argument('--use-only-exists-words', dest='use_only_exists_words', action='store_true')
    parser.add_argument('--swear-words', dest='swear_words', action='store', help='/path/to/swear_words_file', type=str, default=None)
    parser.add_argument('--wrong-words', dest='wrong_words', action='store', help='/path/to/wrong_words_file', type=str, default=None)
    parser.add_argument('--format-embeds', dest='format_embeds', action='store', help='file | json | pickle | binary', type=str, default='raw')
    parser.add_argument('--output-dir', dest='output_dir', action='store', help='/path/to/dir', type=str, default='.')
    parser.add_argument('--norm-prob', dest='norm_prob', action='store_true')
    parser.add_argument('--norm-prob-koef', dest='norm_prob_koef', action='store', type=float, default=1)
    parser.add_argument('--gpus', dest='gpus', action='store', help='count GPUs', type=int, default=0)
    # vars() exposes the Namespace as a dict; no external iteritems helper needed.
    kwargs.update(vars(parser.parse_args()))
    return kwargs
Example #6
Source File: main.py From dcase2018_baseline with MIT License | 6 votes |
def main():
    """Dispatch to training, evaluation or inference according to ``flags.mode``.

    Raises:
        ValueError: If ``flags.mode`` is not 'train', 'eval' or 'inference'.
    """
    flags = parse_flags()
    hparams = parse_hparams(flags.hparams)

    if flags.mode == 'train':
        train.train(model_name=flags.model, hparams=hparams,
                    class_map_path=flags.class_map_path,
                    train_csv_path=flags.train_csv_path,
                    train_clip_dir=flags.train_clip_dir,
                    train_dir=flags.train_dir)
    elif flags.mode == 'eval':
        evaluation.evaluate(model_name=flags.model, hparams=hparams,
                            class_map_path=flags.class_map_path,
                            eval_csv_path=flags.eval_csv_path,
                            eval_clip_dir=flags.eval_clip_dir,
                            checkpoint_path=flags.checkpoint_path)
    elif flags.mode == 'inference':
        inference.predict(model_name=flags.model, hparams=hparams,
                          class_map_path=flags.class_map_path,
                          test_clip_dir=flags.test_clip_dir,
                          checkpoint_path=flags.checkpoint_path,
                          predictions_csv_path=flags.predictions_csv_path)
    else:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently run inference for any other mode.
        raise ValueError('Unknown mode: {!r}'.format(flags.mode))
Example #7
Source File: evaluate.py From BirdCLEF-Baseline with MIT License | 6 votes |
def evaluate():
    """Train a freshly built BirdNET and score its best snapshot.

    Returns:
        A 1x1 float32 numpy array holding the MLRAP value from ``test.test``.
    """
    # Reset statistics before the run.
    stats.clearStats(True)

    # Parsing the dataset also sets the global class list.
    cfg.CLASSES, train_split, val_split = train.parseDataset()

    # Build a new net, train it, and keep the best net train.train returns.
    best_net = train.train(birdnet.build_model(), train_split, val_split)

    # Reload the trained net and evaluate it.
    snapshot = io.loadModel(best_net)
    mlrap, _time_per_epoch = test.test(snapshot)

    return np.array([[mlrap]], dtype='float32')
Example #8
Source File: main.py From ban-vqa with MIT License | 6 votes |
def parse_args():
    """Build the BAN command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option listed below.
    """
    # (flag, add_argument keyword options), registered in this exact order so
    # the generated --help text is unchanged.
    option_specs = (
        ('--task', dict(type=str, default='vqa', help='vqa or flickr')),
        ('--epochs', dict(type=int, default=13)),
        ('--num_hid', dict(type=int, default=1280)),
        ('--model', dict(type=str, default='ban')),
        ('--op', dict(type=str, default='c')),
        ('--gamma', dict(type=int, default=8, help='glimpse')),
        ('--use_both', dict(action='store_true', help='use both train/val datasets to train?')),
        ('--use_vg', dict(action='store_true', help='use visual genome dataset to train?')),
        # store_false: tfidf defaults to True and passing the flag turns it off.
        ('--tfidf', dict(action='store_false', help='tfidf word embedding?')),
        ('--input', dict(type=str, default=None)),
        ('--output', dict(type=str, default='saved_models/ban')),
        ('--batch_size', dict(type=int, default=256)),
        ('--seed', dict(type=int, default=1204, help='random seed')),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
Example #9
Source File: main.py From xgboost-operator with Apache License 2.0 | 6 votes |
def main(args):
    """Run a distributed XGBoost job selected by ``args.job_type``.

    Job types:
        "Predict": run ``predict(args)``.
        "Train":   train, then dump the model if training returned one.
        "All":     train (and dump), then predict.

    Raises:
        Exception: If ``args.model_storage_type`` is not "local" or "oss".
        ValueError: If ``args.job_type`` is not one of the types above
            (previously it was silently ignored).
    """
    model_storage_type = args.model_storage_type
    if model_storage_type not in ("local", "oss"):
        raise Exception("Only supports storage types like local and OSS")
    print("The storage type is " + model_storage_type)

    if args.job_type == "Predict":
        logging.info("starting the predict job")
        predict(args)
    elif args.job_type == "Train":
        logging.info("starting the train job")
        _train_and_dump(args, model_storage_type)
    elif args.job_type == "All":
        logging.info("starting the train and predict job")
        # This branch used to only log; run both phases as the message says.
        # The model is dumped before predicting.
        _train_and_dump(args, model_storage_type)
        predict(args)
    else:
        raise ValueError("Unsupported job type: {}".format(args.job_type))

    logging.info("Finish distributed XGBoost job")


def _train_and_dump(args, model_storage_type):
    """Train via ``train(args)`` and dump the returned model, if any, to ``args.model_path``."""
    model = train(args)
    if model is not None:
        logging.info("finish the model training, and start to dump model ")
        dump_model(model, model_storage_type, args.model_path, args)
Example #10
Source File: train_main.py From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def _provide_real_images(batch_size, **kwargs):
    """Provides real images.

    ``dataset_name`` (a named dataset) takes precedence over
    ``dataset_file_pattern`` (image files on disk). Patches are sized to the
    final resolution of the schedule built from ``kwargs``.

    Args:
        batch_size: Number of images per batch.
        **kwargs: Training config. Must contain ``colors`` and at least one of
            ``dataset_name`` / ``dataset_file_pattern``.

    Returns:
        Whatever the selected ``data_provider`` function returns.

    Raises:
        KeyError: If ``colors`` is missing from ``kwargs``.
        ValueError: If neither ``dataset_name`` nor ``dataset_file_pattern``
            is set.
    """
    dataset_name = kwargs.get('dataset_name')
    dataset_file_pattern = kwargs.get('dataset_file_pattern')
    colors = kwargs['colors']
    final_height, final_width = train.make_resolution_schedule(
        **kwargs).final_resolutions

    if dataset_name is not None:
        return data_provider.provide_data(
            dataset_name=dataset_name,
            split_name='train',
            batch_size=batch_size,
            patch_height=final_height,
            patch_width=final_width,
            colors=colors)
    if dataset_file_pattern is not None:
        return data_provider.provide_data_from_image_files(
            file_pattern=dataset_file_pattern,
            batch_size=batch_size,
            patch_height=final_height,
            patch_width=final_width,
            colors=colors)
    # Previously this fell through and implicitly returned None, handing the
    # caller a missing input instead of reporting the misconfiguration.
    raise ValueError(
        'Either dataset_name or dataset_file_pattern must be provided.')
Example #11
Source File: cli.py From strsum with Apache License 2.0 | 6 votes |
def main(_):
    """Load data, register derived flags, then train or evaluate per ``mode``.

    Args:
        _: Unused positional argument.
    """
    config = flags.FLAGS
    print(str(config.flag_values_dict()))
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu

    print('loading data...')
    (train_batches, dev_batches, test_batches,
     embedding_matrix, vocab, word_to_id) = load_data(config)

    # Vocabulary ids of the special tokens, registered as integer flags.
    special_tokens = (('PAD_IDX', PAD), ('UNK_IDX', UNK),
                      ('BOS_IDX', BOS), ('EOS_IDX', EOS))
    for flag_name, token in special_tokens:
        flags.DEFINE_integer(flag_name, word_to_id[token], flag_name)

    n_embed, d_embed = embedding_matrix.shape
    flags.DEFINE_integer('n_embed', n_embed, 'n_embed')
    flags.DEFINE_integer('d_embed', d_embed, 'd_embed')

    # Largest `_max_sent_len` over every document of every dev batch.
    longest_sentence = max(
        max(doc._max_sent_len(None) for doc in batch)
        for _ct, batch in dev_batches)
    flags.DEFINE_integer('maximum_iterations', longest_sentence, 'maximum_iterations')

    if config.mode == 'train':
        train(config, train_batches, dev_batches, test_batches, embedding_matrix, vocab)
    elif config.mode == 'eval':
        evaluate(config, test_batches, vocab)
Example #12
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_train_success(self):
    """Every stage of the schedule builds, summarizes and trains on random data."""
    cfg = self._config
    root_dir = cfg['train_root_dir']
    if not tf.gfile.Exists(root_dir):
        tf.gfile.MakeDirs(root_dir)

    for stage in train.get_stage_ids(**cfg):
        stage_batch_size = train.get_batch_size(stage, **cfg)
        tf.reset_default_graph()
        images = provide_random_data(batch_size=stage_batch_size)
        model = train.build_model(stage, stage_batch_size, images, **cfg)
        train.add_model_summaries(model, **cfg)
        train.train(model, **cfg)
Example #13
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_get_batch_size(self):
    """get_batch_size returns the scheduled batch size for every stage id."""
    config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
    # batch_size_schedule is expanded to [8, 8, 8, 4, 2];
    # at stage level it becomes [8, 8, 8, 8, 8, 4, 4, 2, 2].
    expected_per_stage = [8, 8, 8, 8, 8, 4, 4, 2, 2]
    for stage_id, expected in enumerate(expected_per_stage):
        actual = train.get_batch_size(stage_id, **config)
        self.assertEqual(actual, expected)
Example #14
Source File: main.py From DeepLearningImplementations with MIT License | 5 votes |
def launch_training(**kwargs):
    """Start training by forwarding every keyword argument to ``train.train``.

    Returns None; the result of ``train.train`` is discarded.
    """
    training_entry_point = train.train
    training_entry_point(**kwargs)
Example #15
Source File: main.py From DeepLearningImplementations with MIT License | 5 votes |
def launch_training(model_name, **kwargs):
    """Start training ``model_name`` with the caller's keyword settings.

    Args:
        model_name: Passed as the first positional argument to ``train.train``.
        **kwargs: Training settings forwarded to ``train.train``.
    """
    # Forward the caller's settings. The previous body passed ``d_params``,
    # which is not defined in this function, and ignored ``kwargs`` entirely.
    train.train(model_name, **kwargs)
Example #16
Source File: main.py From DeepLearningImplementations with MIT License | 5 votes |
def launch_training(**kwargs):
    """Start training by forwarding every keyword argument to ``train.train``.

    Returns None; the result of ``train.train`` is discarded.
    """
    training_entry_point = train.train
    training_entry_point(**kwargs)
Example #17
Source File: main.py From arbitrary_style_transfer with MIT License | 5 votes |
def main():
    """Train one model per style weight, or stylize images with each model.

    IS_TRAINING selects the mode. STYLE_WEIGHTS and MODEL_SAVE_PATHS are
    walked in lockstep.
    """
    if not IS_TRAINING:
        content_imgs_path = list_images(INFERRING_CONTENT_DIR)
        style_imgs_path = list_images(INFERRING_STYLE_DIR)
        for style_weight, model_save_path in zip(STYLE_WEIGHTS, MODEL_SAVE_PATHS):
            print('\n>>> Begin to stylize images with style weight: %.2f\n' % style_weight)
            stylize(content_imgs_path, style_imgs_path, OUTPUTS_DIR,
                    ENCODER_WEIGHTS_PATH, model_save_path,
                    suffix='-' + str(style_weight))
        print('\n>>> Successfully! Done all stylizing...\n')
        return

    content_imgs_path = list_images(TRAINING_CONTENT_DIR)
    style_imgs_path = list_images(TRAINING_STYLE_DIR)
    for style_weight, model_save_path in zip(STYLE_WEIGHTS, MODEL_SAVE_PATHS):
        print('\n>>> Begin to train the network with the style weight: %.2f\n' % style_weight)
        train(style_weight, content_imgs_path, style_imgs_path,
              ENCODER_WEIGHTS_PATH, model_save_path,
              logging_period=LOGGING_PERIOD, debug=True)
    print('\n>>> Successfully! Done all training...\n')
Example #18
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_get_batch_size(self):
    """get_batch_size returns the scheduled batch size for every stage id."""
    config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
    # batch_size_schedule is expanded to [8, 8, 8, 4, 2];
    # at stage level it becomes [8, 8, 8, 8, 8, 4, 4, 2, 2].
    expected_per_stage = [8, 8, 8, 8, 8, 4, 4, 2, 2]
    for stage_id, expected in enumerate(expected_per_stage):
        actual = train.get_batch_size(stage_id, **config)
        self.assertEqual(actual, expected)
Example #19
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_train_success(self):
    """Every stage of the schedule builds, summarizes and trains on random data."""
    cfg = self._config
    root_dir = cfg['train_root_dir']
    if not tf.gfile.Exists(root_dir):
        tf.gfile.MakeDirs(root_dir)

    for stage in train.get_stage_ids(**cfg):
        stage_batch_size = train.get_batch_size(stage, **cfg)
        tf.reset_default_graph()
        images = provide_random_data(batch_size=stage_batch_size)
        model = train.build_model(stage, stage_batch_size, images, **cfg)
        train.add_model_summaries(model, **cfg)
        train.train(model, **cfg)
Example #20
Source File: test_dual_net.py From training with Apache License 2.0 | 5 votes |
def test_train(self):
    """Training runs on a TFRecord built from the bundled example SGF game."""
    with tempfile.TemporaryDirectory() as work_dir:
        with tempfile.NamedTemporaryFile() as record_file:
            flags.FLAGS.work_dir = work_dir
            preprocessing.make_dataset_from_sgf(
                'tests/example_game.sgf', record_file.name)
            train.train([record_file.name])
Example #21
Source File: facial_recognition.py From personal-photos-model with Apache License 2.0 | 5 votes |
def parse_command_line():
    """Parse command-line flags and run the requested steps.

    Steps run in this order, each only when its flag is given: data
    preparation (--prepare-data), visualization (--visualize), training
    (--train, using --graph/--weights/--note), and a same/different
    prediction for two images (--is_same IMG1 IMG2).

    Exits with status 1 if the CAFFE_HOME environment variable is not set.
    """
    import sys  # local import: only needed for the early exit below

    parser = argparse.ArgumentParser(
        description="""Train, validate, and test a face detection classifier that will determine if two faces are the same or different.""")
    parser.add_argument("-p", "--prepare-data", help="Prepare training and validation data.",
                        action="store_true")
    parser.add_argument("-t", "--train", help="""Train classifier. Use --graph to generate quality graphs""",
                        action="store_true")
    parser.add_argument("-g", "--graph", help="Generate training graphs.", action="store_true")
    parser.add_argument("--weights", help="""The trained model weights to use; if not provided defaults to the network that was just trained""",
                        type=str, default=None)
    parser.add_argument("--note", help="Adds extra note onto generated quality graph.", type=str)
    parser.add_argument("-s", "--is_same", help="""Determines if the two images provided are the same or different. Provide relative paths to both images.""",
                        nargs=2, type=str)
    parser.add_argument("--visualize", help="""Writes out various visualizations of the facial images.""",
                        action="store_true")

    args = vars(parser.parse_args())

    if os.environ.get("CAFFE_HOME") is None:
        # print() with one argument behaves the same on Python 2 and 3; the
        # former print statements were a SyntaxError under Python 3.
        print("You must set CAFFE_HOME to point to where Caffe is installed. Example:")
        print("export CAFFE_HOME=/usr/local/caffe")
        sys.exit(1)

    # Ensure the random number generator always starts from the same place for consistent tests.
    random.seed(0)

    webface = WebFace()

    if args["prepare_data"]:
        webface.load_data()
        webface.pair_data()

    if args["visualize"]:
        # TODO: Adapt this to WebFace, not just LFW.
        visualize()

    if args["train"]:
        train(args["graph"], data=webface, weight_file=args["weights"], note=args["note"])

    if args["is_same"] is not None:
        # TODO: Fill this out once we have a threshold and neural network trained.
        images = args["is_same"]
        predict(images[0], images[1])
Example #22
Source File: tuning.py From SGC with MIT License | 5 votes |
def linear_objective(space):
    """Hyperopt objective: fit a linear model for one sampled weight decay.

    Args:
        space: Sampled point; its 'weight_decay' entry is passed to train_linear.

    Returns:
        dict with 'loss' set to the negated validation accuracy and
        'status' set to STATUS_OK.
    """
    model = get_model(args.model, nfeat=feat_dict["train"].size(1),
                      nclass=nclass, nhid=0, dropout=0, cuda=args.cuda)
    weight_decay = space['weight_decay']
    is_mr_dataset = args.dataset == "mr"
    val_acc, _, _ = train_linear(model, feat_dict, weight_decay, is_mr_dataset)

    report = ('weight decay ' + str(weight_decay) + '\n'
              + 'overall accuracy: ' + str(val_acc))
    print(report)
    return {'loss': -val_acc, 'status': STATUS_OK}

# Hyperparameter optimization
Example #23
Source File: stpn_main.py From STPN with Apache License 2.0 | 5 votes |
def main(_):
    """Build an STPN model and either train it (RGB, then FLOW) or test it.

    Args:
        _: Unused positional argument.
    """
    args = utils.parse_args()
    ckpt = utils.ckpt_path(args.ckpt)
    run_params = {
        'batch_size': args.batch_size,
        'beta': args.beta,
        'learning_rate': args.learning_rate,
        'ckpt': ckpt,
        'class_threshold': args.class_th,
        'scale': args.scale,
    }

    # Let TensorFlow grow GPU memory on demand.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    tf.reset_default_graph()
    model = StpnModel()

    with tf.Session(config=session_config) as sess:
        init_op = tf.global_variables_initializer()
        if args.mode == 'train':
            # Variables are re-initialized before each stream is trained.
            for stream in ('rgb', 'flow'):
                sess.run(init_op)
                train(sess, model, run_params, stream, args.training_num)
        elif args.mode == 'test':
            sess.run(init_op)
            test(sess, model, init_op, run_params, args.test_iter)
Example #24
Source File: overfit.py From 3D-HourGlass-Network with MIT License | 5 votes |
def main():
    """Train the 3D hourglass model on H36M 'train' data, logging each epoch."""
    opt = opts().parse()
    logger = Logger(opt.saveDir + '/logs_{}'.format(datetime.datetime.now().isoformat()))

    # 'none' inflates a model from opt, 'scratch' builds a fresh Pose3D,
    # anything else is treated as a path to a saved model.
    if opt.loadModel == 'none':
        model = inflate(opt).cuda()
    elif opt.loadModel == 'scratch':
        model = Pose3D(opt.nChannels, opt.nStack, opt.nModules, opt.numReductions,
                       opt.nRegModules, opt.nRegFrames, ref.nJoints).cuda()
    else:
        model = torch.load(opt.loadModel).cuda()

    train_loader = torch.utils.data.DataLoader(
        h36m('train', opt),
        batch_size=opt.dataloaderSize,
        shuffle=False,
        num_workers=int(ref.nThreads),
    )

    optimizer = torch.optim.RMSprop(
        [{'params': model.parameters(), 'lr': opt.LRhg}],
        alpha=ref.alpha,
        eps=ref.epsilon,
        weight_decay=ref.weightDecay,
        momentum=ref.momentum,
    )

    for epoch in range(1, opt.nEpochs + 1):
        loss_train, acc_train = train(epoch, opt, train_loader, model, optimizer)
        for tag, value in (('loss_train', loss_train), ('acc_train', acc_train)):
            logger.scalar_summary(tag, value, epoch)
        logger.write('{:8f} {:8f} \n'.format(loss_train, acc_train))

    logger.close()
Example #25
Source File: main.py From snip-public with MIT License | 5 votes |
def main():
    """Prune, train and test the model, then exit the process via sys.exit()."""
    args = parse_arguments()
    options = vars(args)

    dataset = Dataset(**options)

    # Fresh default graph with a fixed graph-level seed.
    tf.reset_default_graph()
    tf.set_random_seed(9)

    model = Model(num_classes=dataset.num_classes, **options)
    model.construct_model()

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()

    prune.prune(args, model, sess, dataset)
    train.train(args, model, sess, dataset)
    test.test(args, model, sess, dataset)

    sess.close()
    sys.exit()
Example #26
Source File: augmentation_search.py From kaggle-hpa with BSD 2-Clause "Simplified" License | 5 votes |
def run(config, num_samples=50):
    """Random search over augmentation policies.

    Scores a baseline (empty policy) and ``num_samples`` randomly sampled
    policies with ``search_once``, logging each score to TensorBoard under
    'val/f1'. The five best policies are written to
    ``<config.train.dir>/best_policy.data`` and printed.

    Args:
        config: Experiment config; ``config.train.dir`` is the output directory.
        num_samples: Number of random policies to evaluate (defaults to 50,
            the previously hard-coded value).
    """
    writer = SummaryWriter(config.train.dir)
    try:
        utils.prepare_train_directories(config)

        # base_policy
        score = search_once(config, [])
        print('===============================')
        print('base score:', score)
        writer.add_scalar('val/f1', score, 0)

        policies = []
        for i in range(num_samples):
            policy = sample_policy()
            score = search_once(config, policy)
            writer.add_scalar('val/f1', score, i + 1)
            policies.append((score, policy))
    finally:
        # The writer was previously never closed; close it even if a run fails.
        writer.close()

    # Keep the five highest-scoring policies, in ascending score order.
    # Sorting by score only avoids comparing policies on ties.
    best = sorted(policies, key=lambda v: v[0])[-5:]
    with open(os.path.join(config.train.dir, 'best_policy.data'), 'w') as fid:
        fid.write(str([v[1] for v in best]))

    for score, policy in best:
        print('score:', score)
        print('policy:', policy)
Example #27
Source File: augmentation_search.py From kaggle-hpa with BSD 2-Clause "Simplified" License | 5 votes |
def search_once(config, policy):
    """Train once with augmentation ``policy``; return the 'f1_mavg' score."""
    model = get_model(config).cuda()
    criterion = get_loss(config)
    optimizer = get_optimizer(config, model.parameters())
    scheduler = get_scheduler(config, optimizer, -1)

    # Only the training split is augmented with the candidate policy.
    transforms = {
        'train': get_transform(config, 'train', params={'policies': policy}),
        'val': get_transform(config, 'val'),
    }
    dataloaders = {}
    for split in ('train', 'val'):
        dataloaders[split] = get_dataloader(config, split, transforms[split])

    scores = train(config, model, dataloaders, criterion, optimizer, scheduler, None, 0)
    return scores['f1_mavg']
Example #28
Source File: runner.py From dcase2019_task2_baseline with MIT License | 5 votes |
def main(argv):
    """Dispatch to training, evaluation or inference according to ``flags.mode``.

    Args:
        argv: Unused here.

    Raises:
        ValueError: If ``flags.mode`` is not 'train', 'eval' or 'inference'.
    """
    hparams = model.parse_hparams(flags.hparams)

    if flags.mode == 'train':
        def split_csv(scopes):
            # Comma-separated string -> list; empty or None -> None.
            return scopes.split(',') if scopes else None

        train.train(model_name=flags.model, hparams=hparams,
                    class_map_path=flags.class_map_path,
                    train_csv_path=flags.train_csv_path,
                    train_clip_dir=flags.train_clip_dir,
                    train_dir=flags.train_dir,
                    epoch_batches=flags.epoch_num_batches,
                    warmstart_checkpoint=flags.warmstart_checkpoint,
                    warmstart_include_scopes=split_csv(flags.warmstart_include_scopes),
                    warmstart_exclude_scopes=split_csv(flags.warmstart_exclude_scopes))
    elif flags.mode == 'eval':
        evaluation.evaluate(model_name=flags.model, hparams=hparams,
                            class_map_path=flags.class_map_path,
                            eval_csv_path=flags.eval_csv_path,
                            eval_clip_dir=flags.eval_clip_dir,
                            eval_dir=flags.eval_dir,
                            train_dir=flags.train_dir)
    elif flags.mode == 'inference':
        inference.predict(model_name=flags.model, hparams=hparams,
                          class_map_path=flags.class_map_path,
                          inference_clip_dir=flags.inference_clip_dir,
                          inference_checkpoint=flags.inference_checkpoint,
                          predictions_csv_path=flags.predictions_csv_path)
    else:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently run inference for any other mode.
        raise ValueError('Unknown mode: {!r}'.format(flags.mode))
Example #29
Source File: optimize_hyperparams.py From facial-expression-recognition-using-cnn with GNU General Public License v3.0 | 5 votes |
def function_to_minimize(hyperparams, optimizer=HYPERPARAMS.optimizer, optimizer_param=HYPERPARAMS.optimizer_param,
                         learning_rate=HYPERPARAMS.learning_rate, keep_prob=HYPERPARAMS.keep_prob,
                         learning_rate_decay=HYPERPARAMS.learning_rate_decay):
    """Hyperopt objective: train the CNN with the sampled settings.

    Any of 'learning_rate', 'learning_rate_decay', 'keep_prob', 'optimizer'
    and 'optimizer_param' present in ``hyperparams`` overrides the matching
    keyword default (taken from HYPERPARAMS when the function was defined).

    Returns:
        dict with 'loss' (negated accuracy), 'time' (training seconds, rounded)
        and 'status' (STATUS_OK).

    On any exception during training, ``train_history`` is saved to
    train_history.npy and the process exits.
    """
    global current_eval
    global max_evals

    learning_rate = hyperparams.get('learning_rate', learning_rate)
    learning_rate_decay = hyperparams.get('learning_rate_decay', learning_rate_decay)
    keep_prob = hyperparams.get('keep_prob', keep_prob)
    optimizer = hyperparams.get('optimizer', optimizer)
    optimizer_param = hyperparams.get('optimizer_param', optimizer_param)

    banner = "#################################"
    print(banner)
    print(" Evaluation {} of {}".format(current_eval, max_evals))
    print(banner)

    started_at = time.time()
    try:
        accuracy = train(learning_rate=learning_rate, learning_rate_decay=learning_rate_decay,
                         optimizer=optimizer, optimizer_param=optimizer_param, keep_prob=keep_prob)
        training_time = int(round(time.time() - started_at))
        current_eval += 1
        train_history.append({'accuracy': accuracy, 'learning_rate': learning_rate,
                              'learning_rate_decay': learning_rate_decay, 'optimizer': optimizer,
                              'optimizer_param': optimizer_param, 'keep_prob': keep_prob,
                              'time': training_time})
    except Exception as e:
        # An exception occurred during training: save the history and stop.
        print(banner)
        print("Exception during training: {}".format(str(e)))
        print("Saving train history in train_history.npy")
        np.save("train_history.npy", train_history)
        exit()
    return {'loss': -accuracy, 'time': training_time, 'status': STATUS_OK}

# launch the hyperparameters search
Example #30
Source File: main.py From squash-generation with MIT License | 5 votes |
def main():
    """Parse arguments, seed RNGs, load the YAML config, then dispatch on ``args.mode``.

    Modes 'test' and 'generate' are currently no-ops; unknown modes do nothing.
    """
    args = parser.parse_args()
    modify_arguments(args)

    # setting random seeds
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    with open(args.config_file, 'r') as stream:
        # safe_load: yaml.load without an explicit Loader can construct
        # arbitrary Python objects, has been deprecated since PyYAML 5.1, and
        # raises TypeError on PyYAML >= 6.
        config = yaml.safe_load(stream)
    args.config = Munch(modify_config(args, config))

    logger.info(args)

    if args.mode == 'train':
        train.train(args, device)
    elif args.mode == 'test':
        pass
    elif args.mode == 'analysis':
        analysis.analyze(args, device)
    elif args.mode == 'generate':
        pass
    elif args.mode == 'classify':
        analysis.classify(args, device)
    elif args.mode == 'classify_coqa':
        analysis.classify_coqa(args, device)
    elif args.mode == 'classify_final':
        analysis.classify_final(args, device)