Python model.model.Model() Examples
The following are 11 code examples of model.model.Model().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module model.model, or try the search function.
Example #1
Source File: model_test.py From pytorch-project-template with Apache License 2.0 | 6 votes |
def test_save_load_network(self):
    """Check that a saved network checkpoint can be loaded into a fresh model.

    Saves ``self.model``'s network, points ``hp.load.network_chkpt_path`` at
    the resulting file, loads it into a newly built ``Model`` and asserts
    that every parameter tensor matches the original one.
    """
    fresh_net = Net_arch(self.hp)
    self.loss_f = nn.MSELoss()
    restored_model = Model(self.hp, fresh_net, self.loss_f)

    # Write the checkpoint and work out where it should have landed.
    self.model.save_network(self.logger)
    checkpoint_name = "%s_%d.pt" % (self.hp.log.name, self.model.step)
    checkpoint_path = os.path.join(self.hp.log.chkpt_dir, checkpoint_name)
    self.hp.load.network_chkpt_path = checkpoint_path

    # Both the checkpoint and the log file must exist as regular files.
    assert os.path.exists(checkpoint_path) and os.path.isfile(checkpoint_path)
    log_path = self.hp.log.log_file_path
    assert os.path.exists(log_path) and os.path.isfile(log_path)

    restored_model.load_network(logger=self.logger)

    # Compare parameters pairwise: restored vs. original.
    restored_params = list(restored_model.net.parameters())
    original_params = list(self.model.net.parameters())
    for restored, original in zip(restored_params, original_params):
        assert (restored == original).all()
Example #2
Source File: model_test.py From pytorch-project-template with Apache License 2.0 | 6 votes |
def test_save_load_state(self):
    """Check that a saved training state can be resumed by a fresh model.

    Saves ``self.model``'s training state, points
    ``hp.load.resume_state_path`` at the resulting file, loads it into a
    newly built ``Model`` and asserts that parameters, epoch and step all
    match the original model.
    """
    fresh_net = Net_arch(self.hp)
    self.loss_f = nn.MSELoss()
    restored_model = Model(self.hp, fresh_net, self.loss_f)

    # Write the state file and work out where it should have landed.
    self.model.save_training_state(self.logger)
    state_name = "%s_%d.state" % (self.hp.log.name, self.model.step)
    state_path = os.path.join(self.hp.log.chkpt_dir, state_name)
    self.hp.load.resume_state_path = state_path

    # Both the state file and the log file must exist as regular files.
    assert os.path.exists(state_path) and os.path.isfile(state_path)
    log_path = self.hp.log.log_file_path
    assert os.path.exists(log_path) and os.path.isfile(log_path)

    restored_model.load_training_state(logger=self.logger)

    # Compare parameters pairwise: restored vs. original.
    restored_params = list(restored_model.net.parameters())
    original_params = list(self.model.net.parameters())
    for restored, original in zip(restored_params, original_params):
        assert (restored == original).all()

    # Training progress counters must survive the round trip as well.
    assert restored_model.epoch == self.model.epoch
    assert restored_model.step == self.model.step
Example #3
Source File: train.py From CVPR2019-DeepTreeLearningForZeroShotFaceAntispoofing with MIT License | 6 votes |
def main(argv=None):
    """Configure a training run, build the model and train it.

    ``argv`` is accepted but not used by this function.
    """
    # Configurations: override selected fields, in order, then print them.
    cfg = Config()
    overrides = (
        ('DATA_DIR', ['/data/']),
        ('LOG_DIR', './log/model'),
        ('MODE', 'training'),
        ('STEPS_PER_EPOCH_VAL', 180),
    )
    for field, value in overrides:
        setattr(cfg, field, value)
    cfg.display()

    # Get images and labels.
    train_data = Dataset(cfg, 'train')

    # Build the graph, then train; no second dataset is supplied (None).
    network = Model(cfg)
    network.compile()
    network.train(train_data, None)
Example #4
Source File: mono_3d_estimation.py From 3d-vehicle-tracking with BSD 3-Clause "New" or "Revised" License | 6 votes |
def main():
    """Build the model selected by ``args.set`` and run the requested phase.

    Raises:
        ValueError: if ``args.set`` is neither ``'gta'`` nor ``'kitti'``.
    """
    torch.set_num_threads(multiprocessing.cpu_count())
    args = parse_args()

    # The model class is imported lazily, depending on the chosen dataset.
    dataset_name = args.set
    if dataset_name == 'gta':
        from model.model import Model
    elif dataset_name == 'kitti':
        from model.model_cen import Model
    else:
        raise ValueError("Model not found")

    # Wrap in DataParallel and move to the target device.
    base_model = Model(args.arch, args.roi_name, args.down_ratio,
                       args.roi_kernel)
    model = nn.DataParallel(base_model).to(args.device)

    # Any phase other than 'train' or 'test' is silently ignored.
    if args.phase == 'train':
        run_training(model, args)
        return
    if args.phase == 'test':
        test_model(model, args)
Example #5
Source File: model_test.py From pytorch-project-template with Apache License 2.0 | 5 votes |
def setup_method(self, method):
    """Create the network, loss function and model used by each test.

    ``method`` is accepted but not used here.
    """
    super(TestModel, self).setup_method()
    self.net = network = Net_arch(self.hp)
    self.loss_f = criterion = nn.CrossEntropyLoss()
    self.model = Model(self.hp, network, criterion)
Example #6
Source File: __init__.py From hart with GNU General Public License v3.0 | 5 votes |
def convert_layer_to_tensor(layer, dtype=None, name=None, as_ref=False):
    """Return ``layer.output`` for ``Layer``/``Model`` instances.

    Any other object yields ``NotImplemented``. The ``dtype``, ``name`` and
    ``as_ref`` arguments are accepted but not used.
    """
    if isinstance(layer, (Layer, Model)):
        return layer.output
    return NotImplemented
Example #7
Source File: test.py From multiwoz with MIT License | 5 votes |
def loadModelAndData(num):
    """Build the model, reset the output directories and load dialogue lists.

    Relies on a module-level ``args`` object for ``load_param``,
    ``decode_output`` and ``valid_output``.

    Args:
        num: passed to ``model.loadModel(iter=num)`` when ``args.load_param``
            is truthy.

    Returns:
        A ``(model, val_dials, test_dials)`` tuple.
    """
    def _read_json(path):
        # Parse one JSON file and close it again.
        with open(path) as handle:
            return json.load(handle)

    # Load dictionaries
    input_lang_index2word = _read_json('data/input_lang.index2word.json')
    input_lang_word2index = _read_json('data/input_lang.word2index.json')
    output_lang_index2word = _read_json('data/output_lang.index2word.json')
    output_lang_word2index = _read_json('data/output_lang.word2index.json')

    # Reload existing checkpoint
    model = Model(args, input_lang_index2word, output_lang_index2word,
                  input_lang_word2index, output_lang_word2index)
    if args.load_param:
        model.loadModel(iter=num)

    # Start from empty output directories: wipe any existing one, then
    # (re)create it.
    for output_dir in (args.decode_output, args.valid_output):
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)

    # Load validation and test file lists.
    val_dials = _read_json('data/val_dials.json')
    test_dials = _read_json('data/test_dials.json')

    return model, val_dials, test_dials
Example #8
Source File: demo.py From conv-ensemble-str with Apache License 2.0 | 5 votes |
def main():
    """Recognise the text in a single image and print it.

    Command-line options:
        --path: image file to read (required).
        --checkpoint: model checkpoint to use (default ``data/model.ckpt``).

    Raises:
        AssertionError: if the file given by ``--path`` does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True,
                        help='path to image file.')
    # Fixed: the help text used to be a copy of the --path help
    # ('path to image file.'), which mis-described this option.
    parser.add_argument('--checkpoint', type=str, default='data/model.ckpt',
                        help='path to model checkpoint file.')
    args = parser.parse_args()

    params = {
        'checkpoint': args.checkpoint,
        'dataset': {
            'dataset_dir': 'data',
            'charset_filename': 'charset_size=63.txt',
            'max_sequence_length': 30,
        },
        'beam_width': 1,
        'summary': False
    }

    # Build the inference graph around a single-image uint8 placeholder.
    model = Model(params, ModeKeys.INFER)
    image = tf.placeholder(tf.uint8, (1, 32, 100, 3), name='image')
    predictions, _, _ = model({'image': image}, None)

    assert os.path.exists(args.path), '%s does not exists!' % args.path
    raw_image = Image.open(args.path).convert('RGB')
    # The (100, 32) size presumably maps to the placeholder's (32, 100)
    # dims, since PIL sizes are (width, height) -- confirm.
    raw_image = raw_image.resize((100, 32), Image.BILINEAR)
    # Prepend an axis of length 1 so the array matches the placeholder rank.
    raw_image = np.array(raw_image)[None, :]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        predictions = sess.run(predictions, feed_dict={image: raw_image})
        text = predictions['predicted_text'][0]
        print('%s: %s' % (args.path, text))
Example #9
Source File: classifier.py From VerifAI with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, classifier_data):
    """Initialise the base class and load the classifier network.

    Args:
        classifier_data: object providing ``port``, ``bufsize``,
            ``graph_path`` and ``checkpoint_path`` attributes.
    """
    super().__init__(classifier_data.port, classifier_data.bufsize)

    # Create the session, then initialise the network from the given
    # graph and checkpoint paths within it.
    session = tf.Session()
    self.sess = session
    network = Model()
    self.nn = network
    network.init(
        classifier_data.graph_path,
        classifier_data.checkpoint_path,
        session,
    )

    self.lib = getLib()
Example #10
Source File: launcher.py From Attention-OCR with MIT License | 5 votes |
def main(args, defaults):
    """Set up logging, then build and launch the model in a TF session.

    Args:
        args: raw arguments forwarded to ``process_args``.
        defaults: default values forwarded to ``process_args``.
    """
    params = process_args(args, defaults)

    # DEBUG and above goes to the log file; INFO and above is echoed to the
    # console via an extra handler on the root logger.
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
        filename=params.log_path)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter(
            '%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        model = Model(
            phase=params.phase,
            visualize=params.visualize,
            data_path=params.data_path,
            data_base_dir=params.data_base_dir,
            output_dir=params.output_dir,
            batch_size=params.batch_size,
            initial_learning_rate=params.initial_learning_rate,
            num_epoch=params.num_epoch,
            steps_per_checkpoint=params.steps_per_checkpoint,
            target_vocab_size=params.target_vocab_size,
            model_dir=params.model_dir,
            target_embedding_size=params.target_embedding_size,
            attn_num_hidden=params.attn_num_hidden,
            attn_num_layers=params.attn_num_layers,
            clip_gradients=params.clip_gradients,
            max_gradient_norm=params.max_gradient_norm,
            load_model=params.load_model,
            # Fixed at infinity here rather than read from the parameters.
            valid_target_length=float('inf'),
            gpu_id=params.gpu_id,
            use_gru=params.use_gru,
            session=sess)
        model.launch()
Example #11
Source File: decode.py From speaker_extraction with GNU General Public License v3.0 | 4 votes |
def decode(): tfrecords_list, num_batches = read_list(FLAGS.lists_dir, FLAGS.data_type, FLAGS.batch_size) with tf.Graph().as_default(): with tf.device('/cpu:0'): with tf.name_scope('input'): cmvn = np.load(FLAGS.inputs_cmvn) cmvn_aux = np.load(FLAGS.inputs_cmvn.replace('cmvn', 'cmvn_aux')) if FLAGS.with_labels: inputs, inputs_cmvn, inputs_cmvn_aux, labels, lengths, lengths_aux = paddedFIFO_batch(tfrecords_list, FLAGS.batch_size, FLAGS.input_size, FLAGS.output_size, cmvn=cmvn, cmvn_aux=cmvn_aux, with_labels=FLAGS.with_labels, num_enqueuing_threads=1, num_epochs=1, shuffle=False) else: inputs, inputs_cmvn, inputs_cmvn_aux, lengths, lengths_aux = paddedFIFO_batch(tfrecords_list, FLAGS.batch_size, FLAGS.input_size, FLAGS.output_size, cmvn=cmvn, cmvn_aux=cmvn_aux, with_labels=FLAGS.with_labels, num_enqueuing_threads=1, num_epochs=1, shuffle=False) labels = None with tf.name_scope('model'): model = Model(FLAGS, inputs, inputs_cmvn, inputs_cmvn_aux, labels, lengths, lengths_aux, infer=True) init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess = tf.Session() sess.run(init) checkpoint = tf.train.get_checkpoint_state(FLAGS.save_model_dir) if checkpoint and checkpoint.model_checkpoint_path: tf.logging.info("Restore best model from " + checkpoint.model_checkpoint_path) model.saver.restore(sess, checkpoint.model_checkpoint_path) else: tf.logging.fatal("Checkpoint is not found, please check the best model save path is correct.") sys.exit(-1) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) try: for batch in xrange(num_batches): if coord.should_stop(): break sep, mag_lengths = sess.run([model._sep, model._lengths]) for i in xrange(FLAGS.batch_size): filename = tfrecords_list[FLAGS.batch_size*batch+i] (_, name) = os.path.split(filename) (uttid, _) = os.path.splitext(name) noisy_file = os.path.join(FLAGS.noisy_dir, uttid + '.wav') enhan_sig, rate = reconstruct(np.squeeze(sep[i,:mag_lengths[i],:]), 
noisy_file) savepath = os.path.join(FLAGS.rec_dir, uttid + '.wav') wav.write(savepath, rate, enhan_sig) if (batch+1) % 100 == 0: tf.logging.info("Number of batch processed: %d." % (batch+1)) except Exception, e: coord.request_stop(e) finally: