Python train.main() Examples
The following are 30 code examples of `train.main()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `train`, or try the search function.
Example #1
Source File: test_binaries.py From crosentgec with GNU General Public License v3.0 | 6 votes |
def train_translation_model(data_dir, arch, extra_flags=None):
    """Run ``train.main`` once with a small translation-task configuration.

    Args:
        data_dir: Used both as the positional data argument and as
            ``--save-dir``.
        arch: Value passed to ``--arch``.
        extra_flags: Optional list of extra command-line tokens appended
            after the default flags.
    """
    parser = options.get_training_parser()
    default_flags = [
        '--task', 'translation',
        data_dir,
        '--save-dir', data_dir,
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '0.05',
        '--max-tokens', '500',
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
        '--source-lang', 'in',
        '--target-lang', 'out',
    ]
    parsed = options.parse_args_and_arch(parser, default_flags + (extra_flags or []))
    train.main(parsed)
Example #2
Source File: test_binaries.py From helo_word with Apache License 2.0 | 6 votes |
def train_translation_model(data_dir, arch, extra_flags=None):
    """Run ``train.main`` once with a small translation-task configuration.

    Args:
        data_dir: Used both as the positional data argument and as
            ``--save-dir``.
        arch: Value passed to ``--arch``.
        extra_flags: Optional list of extra command-line tokens appended
            after the default flags.
    """
    parser = options.get_training_parser()
    default_flags = [
        '--task', 'translation',
        data_dir,
        '--save-dir', data_dir,
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '0.05',
        '--max-tokens', '500',
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
        '--source-lang', 'in',
        '--target-lang', 'out',
    ]
    parsed = options.parse_args_and_arch(parser, default_flags + (extra_flags or []))
    train.main(parsed)
Example #3
Source File: test_binaries.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def train_language_model(data_dir, arch):
    """Run ``train.main`` once with a small language-modeling configuration.

    Args:
        data_dir: Used as the positional data argument and as ``--save-dir``.
        arch: Value passed to ``--arch``.
    """
    parser = options.get_training_parser()
    cli_tokens = [
        '--task', 'language_modeling',
        data_dir,
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '1.0',
        '--criterion', 'adaptive_loss',
        '--adaptive-softmax-cutoff', '5,10,15',
        '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
        '--decoder-embed-dim', '280',
        '--max-tokens', '500',
        '--tokens-per-sample', '500',
        '--save-dir', data_dir,
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
    ]
    train.main(options.parse_args_and_arch(parser, cli_tokens))
Example #4
Source File: test_binaries.py From helo_word with Apache License 2.0 | 6 votes |
def train_language_model(data_dir, arch, extra_flags=None):
    """Run ``train.main`` once with a small language-modeling configuration.

    Args:
        data_dir: Used as the positional data argument and as ``--save-dir``.
        arch: Value passed to ``--arch``.
        extra_flags: Optional list of extra command-line tokens appended
            after the default flags, mirroring ``train_translation_model``.
            ``None`` (the default) adds nothing, so existing two-argument
            callers behave exactly as before.
    """
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            '--task', 'language_modeling',
            data_dir,
            '--arch', arch,
            '--optimizer', 'nag',
            '--lr', '0.1',
            '--criterion', 'adaptive_loss',
            '--adaptive-softmax-cutoff', '5,10,15',
            '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
            '--decoder-embed-dim', '280',
            '--max-tokens', '500',
            '--tokens-per-sample', '500',
            '--save-dir', data_dir,
            '--max-epoch', '1',
            '--no-progress-bar',
            '--distributed-world-size', '1',
            '--ddp-backend', 'no_c10d',
        ] + (extra_flags or []),
    )
    train.main(train_args)
Example #5
Source File: test_binaries.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def train_translation_model(data_dir, arch, extra_flags=None):
    """Run ``train.main`` once with a small translation-task configuration.

    Args:
        data_dir: Used both as the positional data argument and as
            ``--save-dir``.
        arch: Value passed to ``--arch``.
        extra_flags: Optional list of extra command-line tokens appended
            after the default flags.
    """
    parser = options.get_training_parser()
    default_flags = [
        '--task', 'translation',
        data_dir,
        '--save-dir', data_dir,
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '0.05',
        '--max-tokens', '500',
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
        '--source-lang', 'in',
        '--target-lang', 'out',
    ]
    parsed = options.parse_args_and_arch(parser, default_flags + (extra_flags or []))
    train.main(parsed)
Example #6
Source File: test_binaries.py From crosentgec with GNU General Public License v3.0 | 6 votes |
def train_language_model(data_dir, arch):
    """Run ``train.main`` once with a small language-modeling configuration.

    Args:
        data_dir: Used as the positional data argument and as ``--save-dir``.
        arch: Value passed to ``--arch``.
    """
    parser = options.get_training_parser()
    cli_tokens = [
        '--task', 'language_modeling',
        data_dir,
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '1.0',
        '--criterion', 'adaptive_loss',
        '--adaptive-softmax-cutoff', '5,10,15',
        '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
        '--decoder-embed-dim', '280',
        '--max-tokens', '500',
        '--tokens-per-sample', '500',
        '--save-dir', data_dir,
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
    ]
    train.main(options.parse_args_and_arch(parser, cli_tokens))
Example #7
Source File: test_binaries.py From crosentgec with GNU General Public License v3.0 | 6 votes |
def generate_main(data_dir):
    """Exercise batch and interactive generation against ``checkpoint_last.pt``.

    First calls ``generate.main`` on the ``valid`` subset. Then it reuses the
    same arguments for ``interactive.main``, feeding it a single line through
    a temporary ``sys.stdin``.

    Args:
        data_dir: Directory passed as the data argument; the checkpoint path
            is ``<data_dir>/checkpoint_last.pt``.

    ``sys.stdin`` is restored even if ``interactive.main`` raises.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
        ],
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Restore stdin unconditionally so a failure here does not leave the
        # StringIO stub installed for later code that reads sys.stdin.
        sys.stdin = orig_stdin
Example #8
Source File: test_binaries.py From helo_word with Apache License 2.0 | 5 votes |
def generate_main(data_dir, extra_flags=None):
    """Exercise batch and interactive generation against ``checkpoint_last.pt``.

    First calls ``generate.main`` on the ``valid`` subset with
    ``--print-alignment``. Then it reuses the same arguments for
    ``interactive.main``, feeding it a single line through a temporary
    ``sys.stdin``.

    Args:
        data_dir: Directory passed as the data argument; the checkpoint path
            is ``<data_dir>/checkpoint_last.pt``.
        extra_flags: Optional list of extra command-line tokens appended
            after the default flags.

    ``sys.stdin`` is restored even if ``interactive.main`` raises.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
            '--print-alignment',
        ] + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    # NOTE(review): '-' presumably tells interactive.main to read from stdin
    # (which is stubbed below) -- confirm against interactive.main.
    generate_args.input = '-'
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Restore stdin unconditionally so a failure here does not leave the
        # StringIO stub installed for later code that reads sys.stdin.
        sys.stdin = orig_stdin
Example #9
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_full_flow(self, mock_data_provider):
    """Run ``train.main`` for two steps on stubbed all-zero inputs.

    ``mock_data_provider`` is presumably injected by a ``mock.patch``
    decorator that is not shown here -- TODO confirm.
    """
    FLAGS.eval_dir = self.get_temp_dir()
    FLAGS.batch_size = 16
    FLAGS.max_number_of_steps = 2
    FLAGS.noise_dims = 3

    # Construct mock inputs: zero images, and labels whose first column is 1
    # and remaining nine columns are 0.
    batch = FLAGS.batch_size
    images = np.zeros([batch, 28, 28, 1], dtype=np.float32)
    hot_column = np.ones([batch, 1], dtype=np.int32)
    cold_columns = np.zeros([batch, 9], dtype=np.int32)
    labels = np.concatenate((hot_column, cold_columns), axis=1)

    mock_data_provider.provide_data.return_value = (images, labels, None)
    train.main(None)
Example #10
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_main(self, mock_provide_data):
    """Run ``train.main`` for ten steps with three stubbed image domains.

    ``mock_provide_data`` is presumably injected by a ``mock.patch``
    decorator that is not shown here -- TODO confirm.
    """
    FLAGS.image_file_patterns = ['/tmp/A/*.jpg', '/tmp/B/*.jpg', '/tmp/C/*.jpg']
    FLAGS.max_number_of_steps = 10
    FLAGS.batch_size = 2
    num_domains = 3

    batch = FLAGS.batch_size
    patch = FLAGS.patch_size
    # List repetition reuses the same tensor object for every domain.
    domain_images = [tf.zeros([batch, patch, patch, 3])] * num_domains
    domain_labels = [tf.one_hot([0] * batch, num_domains)] * num_domains
    mock_provide_data.return_value = (domain_images, domain_labels)

    train.main(None)
Example #11
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_build_graph(self, conditional, use_sync_replicas):
    """Build the training graph (zero steps) for one flag combination.

    ``train.data_provider`` is patched for the duration of ``train.main`` so
    that it returns zero images and labels whose first column is 1.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.conditional = conditional
    FLAGS.use_sync_replicas = use_sync_replicas
    FLAGS.batch_size = 16

    # Mock input pipeline.
    n = FLAGS.batch_size
    fake_images = np.zeros([n, 32, 32, 3], dtype=np.float32)
    fake_labels = np.concatenate(
        (np.ones([n, 1], dtype=np.int32), np.zeros([n, 9], dtype=np.int32)),
        axis=1,
    )
    provider_output = (fake_images, fake_labels, None, None)

    with mock.patch.object(train, 'data_provider') as fake_provider:
        fake_provider.provide_data.return_value = provider_output
        train.main(None)
Example #12
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_main(self, mock_gan_train, mock_define_train_ops, mock_cyclegan_loss,
              mock_define_model, mock_data_provider, mock_gfile):
    """Check that ``train.main`` wires flags into the mocked pipeline.

    Each ``mock_*`` argument is presumably injected by a ``mock.patch``
    decorator that is not shown here -- TODO confirm. ``mock_gfile`` is not
    referenced in the body.
    """
    flag_overrides = (
        ('image_set_x_file_pattern', '/tmp/x/*.jpg'),
        ('image_set_y_file_pattern', '/tmp/y/*.jpg'),
        ('batch_size', 3),
        ('patch_size', 8),
        ('generator_lr', 0.02),
        ('discriminator_lr', 0.3),
        ('train_log_dir', '/tmp/foo'),
        ('master', 'master'),
        ('task', 0),
        ('cycle_consistency_loss_weight', 2.0),
        ('max_number_of_steps', 1),
    )
    for flag_name, flag_value in flag_overrides:
        setattr(FLAGS, flag_name, flag_value)

    # Two distinct zero tensors, one per image set.
    mock_data_provider.provide_custom_data.return_value = (
        tf.zeros([3, 2, 2, 3], dtype=tf.float32),
        tf.zeros([3, 2, 2, 3], dtype=tf.float32),
    )

    train.main(None)

    model = mock_define_model.return_value
    loss = mock_cyclegan_loss.return_value
    train_ops = mock_define_train_ops.return_value

    mock_data_provider.provide_custom_data.assert_called_once_with(
        ['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=3, patch_size=8)
    mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
    mock_cyclegan_loss.assert_called_once_with(
        model, cycle_consistency_loss_weight=2.0, tensor_pool_fn=mock.ANY)
    mock_define_train_ops.assert_called_once_with(model, loss)
    mock_gan_train.assert_called_once_with(
        train_ops, '/tmp/foo', get_hooks_fn=mock.ANY, hooks=mock.ANY,
        master='master', is_chief=True)
Example #13
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_build_graph(self, gan_type):
    """Build the training graph (zero steps) for the given ``gan_type``.

    The batch size is whatever ``FLAGS.batch_size`` currently holds.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.gan_type = gan_type

    # Mock input pipeline.
    rows = FLAGS.batch_size
    images = np.zeros([rows, 28, 28, 1], dtype=np.float32)
    labels = np.concatenate(
        (np.ones([rows, 1], dtype=np.int32),
         np.zeros([rows, 9], dtype=np.int32)),
        axis=1,
    )

    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = (images, labels, None)
        train.main(None)
Example #14
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_build_graph(self, weight_factor):
    """Build the training graph (zero steps) for the given ``weight_factor``.

    ``train.data_provider.provide_data`` is stubbed to return a batch of
    9 zero images of size 32x32x3.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = 9
    FLAGS.patch_size = 32

    side = FLAGS.patch_size
    zero_batch = np.zeros([FLAGS.batch_size, side, side, 3], dtype=np.float32)

    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = zero_batch
        train.main(None)
Example #15
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_build_graph(self, weight_factor):
    """Build the training graph (zero steps) for the given ``weight_factor``.

    ``train.data_provider.provide_data`` is stubbed to return a batch of
    3 zero images of size 16x16x3.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    batch_size, patch_size = 3, 16
    FLAGS.batch_size = batch_size
    FLAGS.patch_size = patch_size

    image_shape = (batch_size, patch_size, patch_size, 3)
    fake_images = np.zeros(image_shape, dtype=np.float32)

    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = fake_images
        train.main(None)
Example #16
Source File: train_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def test_main(self):
    """Run ``train.main`` for one step on the bundled test image folders.

    The generator and discriminator passed in are ``_test_generator`` and
    ``_test_discriminator``, defined elsewhere in the test module.
    """
    FLAGS.image_file_patterns = [
        os.path.join(FLAGS.test_srcdir, TESTDATA_DIR, pattern)
        for pattern in ('black/*.jpg', 'blond/*.jpg', 'brown/*.jpg')
    ]
    FLAGS.max_number_of_steps = 1
    FLAGS.steps_per_eval = 1
    FLAGS.batch_size = 1
    train.main(None, _test_generator, _test_discriminator)
Example #17
Source File: test_binaries.py From helo_word with Apache License 2.0 | 5 votes |
def preprocess_translation_data(data_dir, extra_flags=None):
    """Run ``preprocess.main`` on the train/valid/test ``in``/``out`` files.

    Args:
        data_dir: Directory holding the ``train``/``valid``/``test`` prefixes;
            also used as ``--destdir``.
        extra_flags: Optional list of extra command-line tokens appended
            after the default flags.
    """
    parser = options.get_preprocessing_parser()
    argv = [
        '--source-lang', 'in',
        '--target-lang', 'out',
        '--trainpref', os.path.join(data_dir, 'train'),
        '--validpref', os.path.join(data_dir, 'valid'),
        '--testpref', os.path.join(data_dir, 'test'),
        '--thresholdtgt', '0',
        '--thresholdsrc', '0',
        '--destdir', data_dir,
    ] + (extra_flags or [])
    preprocess.main(parser.parse_args(argv))
Example #18
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_main(self):
    """Run ``train.main`` for one step on the bundled test image folders.

    The generator and discriminator passed in are ``_test_generator`` and
    ``_test_discriminator``, defined elsewhere in the test module.
    """
    FLAGS.image_file_patterns = [
        os.path.join(FLAGS.test_srcdir, TESTDATA_DIR, pattern)
        for pattern in ('black/*.jpg', 'blond/*.jpg', 'brown/*.jpg')
    ]
    FLAGS.max_number_of_steps = 1
    FLAGS.steps_per_eval = 1
    FLAGS.batch_size = 1
    train.main(None, _test_generator, _test_discriminator)
Example #19
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_build_graph(self, weight_factor):
    """Build the training graph (zero steps) for the given ``weight_factor``.

    ``train.data_provider.provide_data`` is stubbed to return a batch of
    3 zero images of size 16x16x3.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    batch_size, patch_size = 3, 16
    FLAGS.batch_size = batch_size
    FLAGS.patch_size = patch_size

    image_shape = (batch_size, patch_size, patch_size, 3)
    fake_images = np.zeros(image_shape, dtype=np.float32)

    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = fake_images
        train.main(None)
Example #20
Source File: test_binaries.py From helo_word with Apache License 2.0 | 5 votes |
def preprocess_lm_data(data_dir):
    """Run ``preprocess.main`` in ``--only-source`` mode on the ``*.out`` files.

    Args:
        data_dir: Directory holding ``train.out``/``valid.out``/``test.out``;
            also used as ``--destdir``.
    """
    parser = options.get_preprocessing_parser()
    argv = [
        '--only-source',
        '--trainpref', os.path.join(data_dir, 'train.out'),
        '--validpref', os.path.join(data_dir, 'valid.out'),
        '--testpref', os.path.join(data_dir, 'test.out'),
        '--destdir', data_dir,
    ]
    preprocess.main(parser.parse_args(argv))
Example #21
Source File: test_binaries.py From crosentgec with GNU General Public License v3.0 | 5 votes |
def preprocess_translation_data(data_dir):
    """Run ``preprocess.main`` on the train/valid/test ``in``/``out`` files.

    Args:
        data_dir: Directory holding the ``train``/``valid``/``test`` prefixes;
            also used as ``--destdir``.
    """
    parser = preprocess.get_parser()
    argv = [
        '--source-lang', 'in',
        '--target-lang', 'out',
        '--trainpref', os.path.join(data_dir, 'train'),
        '--validpref', os.path.join(data_dir, 'valid'),
        '--testpref', os.path.join(data_dir, 'test'),
        '--thresholdtgt', '0',
        '--thresholdsrc', '0',
        '--destdir', data_dir,
    ]
    preprocess.main(parser.parse_args(argv))
Example #22
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_full_flow(self, mock_data_provider):
    """Run ``train.main`` for two steps on stubbed all-zero inputs.

    ``mock_data_provider`` is presumably injected by a ``mock.patch``
    decorator that is not shown here -- TODO confirm.
    """
    FLAGS.eval_dir = self.get_temp_dir()
    FLAGS.batch_size = 16
    FLAGS.max_number_of_steps = 2
    FLAGS.noise_dims = 3

    # Construct mock inputs: zero images, and labels whose first column is 1
    # and remaining nine columns are 0.
    batch = FLAGS.batch_size
    images = np.zeros([batch, 28, 28, 1], dtype=np.float32)
    hot_column = np.ones([batch, 1], dtype=np.int32)
    cold_columns = np.zeros([batch, 9], dtype=np.int32)
    labels = np.concatenate((hot_column, cold_columns), axis=1)

    mock_data_provider.provide_data.return_value = (images, labels, None)
    train.main(None)
Example #23
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_main(self, mock_provide_data):
    """Run ``train.main`` for ten steps with three stubbed image domains.

    ``mock_provide_data`` is presumably injected by a ``mock.patch``
    decorator that is not shown here -- TODO confirm.
    """
    FLAGS.image_file_patterns = ['/tmp/A/*.jpg', '/tmp/B/*.jpg', '/tmp/C/*.jpg']
    FLAGS.max_number_of_steps = 10
    FLAGS.batch_size = 2
    num_domains = 3

    batch = FLAGS.batch_size
    patch = FLAGS.patch_size
    # List repetition reuses the same tensor object for every domain.
    domain_images = [tf.zeros([batch, patch, patch, 3])] * num_domains
    domain_labels = [tf.one_hot([0] * batch, num_domains)] * num_domains
    mock_provide_data.return_value = (domain_images, domain_labels)

    train.main(None)
Example #24
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_build_graph(self, conditional, use_sync_replicas):
    """Build the training graph (zero steps) for one flag combination.

    ``train.data_provider`` is patched for the duration of ``train.main`` so
    that it returns zero images and labels whose first column is 1.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.conditional = conditional
    FLAGS.use_sync_replicas = use_sync_replicas
    FLAGS.batch_size = 16

    # Mock input pipeline.
    n = FLAGS.batch_size
    fake_images = np.zeros([n, 32, 32, 3], dtype=np.float32)
    fake_labels = np.concatenate(
        (np.ones([n, 1], dtype=np.int32), np.zeros([n, 9], dtype=np.int32)),
        axis=1,
    )
    provider_output = (fake_images, fake_labels, None, None)

    with mock.patch.object(train, 'data_provider') as fake_provider:
        fake_provider.provide_data.return_value = provider_output
        train.main(None)
Example #25
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_main(self, mock_gan_train, mock_define_train_ops, mock_cyclegan_loss,
              mock_define_model, mock_data_provider, mock_gfile):
    """Check that ``train.main`` wires flags into the mocked pipeline.

    Each ``mock_*`` argument is presumably injected by a ``mock.patch``
    decorator that is not shown here -- TODO confirm. ``mock_gfile`` is not
    referenced in the body.
    """
    flag_overrides = (
        ('image_set_x_file_pattern', '/tmp/x/*.jpg'),
        ('image_set_y_file_pattern', '/tmp/y/*.jpg'),
        ('batch_size', 3),
        ('patch_size', 8),
        ('generator_lr', 0.02),
        ('discriminator_lr', 0.3),
        ('train_log_dir', '/tmp/foo'),
        ('master', 'master'),
        ('task', 0),
        ('cycle_consistency_loss_weight', 2.0),
        ('max_number_of_steps', 1),
    )
    for flag_name, flag_value in flag_overrides:
        setattr(FLAGS, flag_name, flag_value)

    # Two distinct zero tensors, one per image set.
    mock_data_provider.provide_custom_data.return_value = (
        tf.zeros([3, 2, 2, 3], dtype=tf.float32),
        tf.zeros([3, 2, 2, 3], dtype=tf.float32),
    )

    train.main(None)

    model = mock_define_model.return_value
    loss = mock_cyclegan_loss.return_value
    train_ops = mock_define_train_ops.return_value

    mock_data_provider.provide_custom_data.assert_called_once_with(
        ['/tmp/x/*.jpg', '/tmp/y/*.jpg'], batch_size=3, patch_size=8)
    mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
    mock_cyclegan_loss.assert_called_once_with(
        model, cycle_consistency_loss_weight=2.0, tensor_pool_fn=mock.ANY)
    mock_define_train_ops.assert_called_once_with(model, loss)
    mock_gan_train.assert_called_once_with(
        train_ops, '/tmp/foo', get_hooks_fn=mock.ANY, hooks=mock.ANY,
        master='master', is_chief=True)
Example #26
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_build_graph(self, gan_type):
    """Build the training graph (zero steps) for the given ``gan_type``.

    The batch size is whatever ``FLAGS.batch_size`` currently holds.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.gan_type = gan_type

    # Mock input pipeline.
    rows = FLAGS.batch_size
    images = np.zeros([rows, 28, 28, 1], dtype=np.float32)
    labels = np.concatenate(
        (np.ones([rows, 1], dtype=np.int32),
         np.zeros([rows, 9], dtype=np.int32)),
        axis=1,
    )

    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = (images, labels, None)
        train.main(None)
Example #27
Source File: train_test.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def test_build_graph(self, weight_factor):
    """Build the training graph (zero steps) for the given ``weight_factor``.

    ``train.data_provider.provide_data`` is stubbed to return a batch of
    9 zero images of size 32x32x3.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = 9
    FLAGS.patch_size = 32

    side = FLAGS.patch_size
    zero_batch = np.zeros([FLAGS.batch_size, side, side, 3], dtype=np.float32)

    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = zero_batch
        train.main(None)
Example #28
Source File: train_test.py From Gun-Detector with Apache License 2.0 | 5 votes |
def _test_build_graph_helper(self, conditional, use_sync_replicas):
    """Build the training graph (zero steps) for one flag combination.

    ``train.data_provider`` is patched for the duration of ``train.main`` so
    that it returns zero images and labels whose first column is 1.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.conditional = conditional
    FLAGS.use_sync_replicas = use_sync_replicas
    FLAGS.batch_size = 16

    # Mock input pipeline.
    n = FLAGS.batch_size
    fake_images = np.zeros([n, 32, 32, 3], dtype=np.float32)
    fake_labels = np.concatenate(
        (np.ones([n, 1], dtype=np.int32), np.zeros([n, 9], dtype=np.int32)),
        axis=1,
    )
    provider_output = (fake_images, fake_labels, None, None)

    with mock.patch.object(train, 'data_provider') as fake_provider:
        fake_provider.provide_data.return_value = provider_output
        train.main(None)
Example #29
Source File: test_binaries.py From crosentgec with GNU General Public License v3.0 | 5 votes |
def preprocess_lm_data(data_dir):
    """Run ``preprocess.main`` in ``--only-source`` mode on the ``*.out`` files.

    Args:
        data_dir: Directory holding ``train.out``/``valid.out``/``test.out``;
            also used as ``--destdir``.
    """
    parser = preprocess.get_parser()
    argv = [
        '--only-source',
        '--trainpref', os.path.join(data_dir, 'train.out'),
        '--validpref', os.path.join(data_dir, 'valid.out'),
        '--testpref', os.path.join(data_dir, 'test.out'),
        '--destdir', data_dir,
    ]
    preprocess.main(parser.parse_args(argv))
Example #30
Source File: train_test.py From yolo_v2 with Apache License 2.0 | 5 votes |
def test_full_flow(self, mock_data_provider):
    """Run ``train.main`` for two steps on stubbed all-zero inputs.

    ``mock_data_provider`` is presumably injected by a ``mock.patch``
    decorator that is not shown here -- TODO confirm.
    """
    FLAGS.eval_dir = self.get_temp_dir()
    FLAGS.batch_size = 16
    FLAGS.max_number_of_steps = 2
    FLAGS.noise_dims = 3

    # Construct mock inputs: zero images, and labels whose first column is 1
    # and remaining nine columns are 0.
    batch = FLAGS.batch_size
    images = np.zeros([batch, 28, 28, 1], dtype=np.float32)
    hot_column = np.ones([batch, 1], dtype=np.int32)
    cold_columns = np.zeros([batch, 9], dtype=np.int32)
    labels = np.concatenate((hot_column, cold_columns), axis=1)

    mock_data_provider.provide_data.return_value = (images, labels, None)
    train.main(None)