Python model.model.Model() Examples

The following are 11 code examples of model.model.Model(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module model.model, or try the search function.
Example #1
Source File: model_test.py · Project: pytorch-project-template · License: Apache License 2.0 · 6 votes
def test_save_load_network(self):
        """Round-trip the network weights through save_network/load_network.

        Saves ``self.model``'s network to ``<chkpt_dir>/<name>_<step>.pt``,
        points ``hp.load.network_chkpt_path`` at that file, loads it into a
        freshly built Model and checks that every parameter tensor matches.
        """
        local_net = Net_arch(self.hp)
        # Kept local so the fixture's self.loss_f is not overwritten.
        loss_f = nn.MSELoss()
        local_model = Model(self.hp, local_net, loss_f)

        self.model.save_network(self.logger)
        save_filename = "%s_%d.pt" % (self.hp.log.name, self.model.step)
        save_path = os.path.join(self.hp.log.chkpt_dir, save_filename)
        # load_network() is called below without a path, so it presumably
        # takes the checkpoint location from hp.
        self.hp.load.network_chkpt_path = save_path

        # os.path.isfile() is False for missing paths, so it also checks existence.
        assert os.path.isfile(save_path)
        assert os.path.isfile(self.hp.log.log_file_path)

        local_model.load_network(logger=self.logger)
        loaded_params = list(local_model.net.parameters())
        origin_params = list(self.model.net.parameters())
        # zip() stops at the shorter sequence; without this check a model that
        # lost parameters on load would pass vacuously.
        assert len(loaded_params) == len(origin_params)
        for load, origin in zip(loaded_params, origin_params):
            # Guard against broadcasting hiding a shape mismatch in `==`.
            assert load.shape == origin.shape
            assert (load == origin).all()
Example #2
Source File: model_test.py · Project: pytorch-project-template · License: Apache License 2.0 · 6 votes
def test_save_load_state(self):
        """Round-trip the training state through save/load_training_state.

        Saves ``self.model``'s training state to
        ``<chkpt_dir>/<name>_<step>.state``, points
        ``hp.load.resume_state_path`` at that file, loads it into a freshly
        built Model and checks that parameters, epoch and step all match.
        """
        local_net = Net_arch(self.hp)
        # Kept local so the fixture's self.loss_f is not overwritten.
        loss_f = nn.MSELoss()
        local_model = Model(self.hp, local_net, loss_f)

        self.model.save_training_state(self.logger)
        save_filename = "%s_%d.state" % (self.hp.log.name, self.model.step)
        save_path = os.path.join(self.hp.log.chkpt_dir, save_filename)
        # load_training_state() is called below without a path, so it
        # presumably takes the state-file location from hp.
        self.hp.load.resume_state_path = save_path

        # os.path.isfile() is False for missing paths, so it also checks existence.
        assert os.path.isfile(save_path)
        assert os.path.isfile(self.hp.log.log_file_path)

        local_model.load_training_state(logger=self.logger)
        loaded_params = list(local_model.net.parameters())
        origin_params = list(self.model.net.parameters())
        # zip() stops at the shorter sequence; without this check a model that
        # lost parameters on load would pass vacuously.
        assert len(loaded_params) == len(origin_params)
        for load, origin in zip(loaded_params, origin_params):
            # Guard against broadcasting hiding a shape mismatch in `==`.
            assert load.shape == origin.shape
            assert (load == origin).all()
        assert local_model.epoch == self.model.epoch
        assert local_model.step == self.model.step
Example #3
Source File: train.py · Project: CVPR2019-DeepTreeLearningForZeroShotFaceAntispoofing · License: MIT License · 6 votes
def main(argv=None):
    """Configure and train the model on the 'train' split.

    ``argv`` is accepted but not used.
    """
    config = Config()
    # Applied in this order, exactly as listed, before display().
    overrides = (
        ('DATA_DIR', ['/data/']),
        ('LOG_DIR', './log/model'),
        ('MODE', 'training'),
        ('STEPS_PER_EPOCH_VAL', 180),
    )
    for attr_name, attr_value in overrides:
        setattr(config, attr_name, attr_value)
    config.display()

    train_set = Dataset(config, 'train')
    net = Model(config)
    net.compile()
    # The second argument is passed as None (no validation dataset is built here).
    net.train(train_set, None)
Example #4
Source File: mono_3d_estimation.py · Project: 3d-vehicle-tracking · License: BSD 3-Clause "New" or "Revised" License · 6 votes
def main():
    """Build the estimation Model selected by ``args.set`` and run ``args.phase``.

    ``args.set`` chooses the Model class ('gta' -> model.model,
    'kitti' -> model.model_cen). ``args.phase`` is 'train' (run_training)
    or 'test' (test_model).

    Raises:
        ValueError: if ``args.set`` or ``args.phase`` is not recognised.
    """
    torch.set_num_threads(multiprocessing.cpu_count())
    args = parse_args()
    if args.set == 'gta':
        from model.model import Model
    elif args.set == 'kitti':
        from model.model_cen import Model
    else:
        raise ValueError("Model not found")

    # Validate the phase before building the model. Previously an unknown
    # phase silently did nothing after the model was constructed and moved
    # to the device.
    if args.phase not in ('train', 'test'):
        raise ValueError("Unknown phase: %r" % (args.phase,))

    model = Model(args.arch,
                  args.roi_name,
                  args.down_ratio,
                  args.roi_kernel)
    model = nn.DataParallel(model)
    model = model.to(args.device)

    if args.phase == 'train':
        run_training(model, args)
    else:
        test_model(model, args)
Example #5
Source File: model_test.py · Project: pytorch-project-template · License: Apache License 2.0 · 5 votes
def setup_method(self, method):
        """Create a fresh network, cross-entropy loss and Model for each test.

        ``method`` is accepted (pytest passes it) but not used.
        """
        super(TestModel, self).setup_method()
        network = Net_arch(self.hp)
        criterion = nn.CrossEntropyLoss()
        self.net, self.loss_f = network, criterion
        self.model = Model(self.hp, network, criterion)
Example #6
Source File: __init__.py · Project: hart · License: GNU General Public License v3.0 · 5 votes
def convert_layer_to_tensor(layer, dtype=None, name=None, as_ref=False):
    """Return ``layer.output`` for Layer/Model instances, else NotImplemented.

    ``dtype``, ``name`` and ``as_ref`` are accepted but ignored.
    """
    is_supported = isinstance(layer, (Layer, Model))
    return layer.output if is_supported else NotImplemented
Example #7
Source File: test.py · Project: multiwoz · License: MIT License · 5 votes
def _load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path) as f:
        return json.load(f)


def _recreate_dir(path):
    """Make ``path`` an empty directory, removing any existing tree first."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def loadModelAndData(num):
    """Build the Model, optionally restore a checkpoint, and load split lists.

    Args:
        num: checkpoint iteration passed to ``model.loadModel(iter=num)``;
            only used when ``args.load_param`` is truthy.

    Returns:
        ``(model, val_dials, test_dials)`` where the last two are the parsed
        contents of data/val_dials.json and data/test_dials.json.

    Side effects:
        Deletes and recreates ``args.decode_output`` and ``args.valid_output``
        as empty directories. Reads the module-level ``args``.
    """
    # Load vocabulary dictionaries.
    input_lang_index2word = _load_json('data/input_lang.index2word.json')
    input_lang_word2index = _load_json('data/input_lang.word2index.json')
    output_lang_index2word = _load_json('data/output_lang.index2word.json')
    output_lang_word2index = _load_json('data/output_lang.word2index.json')

    # Build the model and optionally reload an existing checkpoint.
    model = Model(args, input_lang_index2word, output_lang_index2word,
                  input_lang_word2index, output_lang_word2index)
    if args.load_param:
        model.loadModel(iter=num)

    # Start from empty output directories.
    _recreate_dir(args.decode_output)
    _recreate_dir(args.valid_output)

    # Validation and test dialogue file lists.
    val_dials = _load_json('data/val_dials.json')
    test_dials = _load_json('data/test_dials.json')
    return model, val_dials, test_dials
Example #8
Source File: demo.py · Project: conv-ensemble-str · License: Apache License 2.0 · 5 votes
def main():
  """Run single-image inference with ``Model`` in INFER mode and print the text.

  Command line:
    --path        image file to recognise (required).
    --checkpoint  model checkpoint (default: data/model.ckpt).

  Prints ``<path>: <predicted text>``. Exits via ``parser.error`` (status 2)
  if ``--path`` does not exist.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('--path', type=str, required=True,
                      help='path to image file.')
  parser.add_argument('--checkpoint', type=str, default='data/model.ckpt',
                      help='path to model checkpoint.')
  args = parser.parse_args()

  # Validate before building the graph. parser.error (unlike assert) is not
  # stripped under `python -O`.
  if not os.path.exists(args.path):
    parser.error('%s does not exist!' % args.path)

  params = {
    'checkpoint': args.checkpoint,
    'dataset': {
      'dataset_dir': 'data',
      'charset_filename': 'charset_size=63.txt',
      'max_sequence_length': 30,
    },
    'beam_width': 1,
    'summary': False
  }
  model = Model(params, ModeKeys.INFER)
  image = tf.placeholder(tf.uint8, (1, 32, 100, 3), name='image')
  predictions, _, _ = model({'image': image}, None)

  # Use a context manager so the source file handle is closed; convert()
  # returns an independent, fully loaded image.
  with Image.open(args.path) as img:
    raw_image = img.convert('RGB')
  raw_image = raw_image.resize((100, 32), Image.BILINEAR)
  # (32, 100, 3) -> (1, 32, 100, 3) to match the placeholder's batch axis.
  raw_image = np.array(raw_image)[None, :]

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    predictions = sess.run(predictions, feed_dict={image: raw_image})
    text = predictions['predicted_text'][0]
    print('%s: %s' % (args.path, text))
Example #9
Source File: classifier.py · Project: VerifAI · License: BSD 3-Clause "New" or "Revised" License · 5 votes
def __init__(self, classifier_data):
		"""Initialise the base class and set up the TF-backed classifier.

		Reads ``port``, ``bufsize``, ``graph_path`` and ``checkpoint_path`` from
		``classifier_data``. Creates ``self.sess`` (a tf.Session), ``self.nn`` (a
		Model initialised with the graph/checkpoint paths and that session) and
		``self.lib`` (from getLib()).
		"""
		super().__init__(classifier_data.port, classifier_data.bufsize)
		self.sess = tf.Session()
		self.nn = Model()
		self.nn.init(classifier_data.graph_path, classifier_data.checkpoint_path, self.sess)
		self.lib = getLib()
Example #10
Source File: launcher.py · Project: Attention-OCR · License: MIT License · 5 votes
def main(args, defaults):
    """Set up logging, build the Model inside a TF session and launch it.

    Root logging is configured at DEBUG to ``parameters.log_path`` (note that
    ``logging.basicConfig`` does nothing if the root logger already has
    handlers). A StreamHandler at INFO is added on top.
    """
    parameters = process_args(args, defaults)

    log_format = '%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s'
    logging.basicConfig(level=logging.DEBUG,
                        format=log_format,
                        filename=parameters.log_path)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(log_format))
    logging.getLogger('').addHandler(console_handler)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        model = Model(
            phase=parameters.phase,
            visualize=parameters.visualize,
            data_path=parameters.data_path,
            data_base_dir=parameters.data_base_dir,
            output_dir=parameters.output_dir,
            batch_size=parameters.batch_size,
            initial_learning_rate=parameters.initial_learning_rate,
            num_epoch=parameters.num_epoch,
            steps_per_checkpoint=parameters.steps_per_checkpoint,
            target_vocab_size=parameters.target_vocab_size,
            model_dir=parameters.model_dir,
            target_embedding_size=parameters.target_embedding_size,
            attn_num_hidden=parameters.attn_num_hidden,
            attn_num_layers=parameters.attn_num_layers,
            clip_gradients=parameters.clip_gradients,
            max_gradient_norm=parameters.max_gradient_norm,
            load_model=parameters.load_model,
            valid_target_length=float('inf'),
            gpu_id=parameters.gpu_id,
            use_gru=parameters.use_gru,
            session=sess,
        )
        model.launch()
Example #11
Source File: decode.py · Project: speaker_extraction · License: GNU General Public License v3.0 · 4 votes
def decode():
    tfrecords_list, num_batches = read_list(FLAGS.lists_dir, FLAGS.data_type, FLAGS.batch_size)

    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                cmvn = np.load(FLAGS.inputs_cmvn)
                cmvn_aux = np.load(FLAGS.inputs_cmvn.replace('cmvn', 'cmvn_aux'))
                if FLAGS.with_labels:
                    inputs, inputs_cmvn, inputs_cmvn_aux, labels, lengths, lengths_aux = paddedFIFO_batch(tfrecords_list, FLAGS.batch_size,
                        FLAGS.input_size, FLAGS.output_size, cmvn=cmvn, cmvn_aux=cmvn_aux, with_labels=FLAGS.with_labels, 
                        num_enqueuing_threads=1, num_epochs=1, shuffle=False)
                else:
                    inputs, inputs_cmvn, inputs_cmvn_aux, lengths, lengths_aux = paddedFIFO_batch(tfrecords_list, FLAGS.batch_size,
                        FLAGS.input_size, FLAGS.output_size, cmvn=cmvn, cmvn_aux=cmvn_aux, with_labels=FLAGS.with_labels,
                        num_enqueuing_threads=1, num_epochs=1, shuffle=False)
                    labels = None
               
        with tf.name_scope('model'):
            model = Model(FLAGS, inputs, inputs_cmvn, inputs_cmvn_aux, labels, lengths, lengths_aux, infer=True)

        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess = tf.Session()
        sess.run(init)

        checkpoint = tf.train.get_checkpoint_state(FLAGS.save_model_dir)
        if checkpoint and checkpoint.model_checkpoint_path:
            tf.logging.info("Restore best model from " + checkpoint.model_checkpoint_path)
            model.saver.restore(sess, checkpoint.model_checkpoint_path)
        else:
            tf.logging.fatal("Checkpoint is not found, please check the best model save path is correct.")
            sys.exit(-1)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for batch in xrange(num_batches):
                if coord.should_stop():
                    break

                sep, mag_lengths = sess.run([model._sep, model._lengths])
                for i in xrange(FLAGS.batch_size):
                    filename = tfrecords_list[FLAGS.batch_size*batch+i]
                    (_, name) = os.path.split(filename)
                    (uttid, _) = os.path.splitext(name)
                    noisy_file = os.path.join(FLAGS.noisy_dir, uttid + '.wav')
                    enhan_sig, rate = reconstruct(np.squeeze(sep[i,:mag_lengths[i],:]), noisy_file)
                    savepath = os.path.join(FLAGS.rec_dir, uttid + '.wav')
                    wav.write(savepath, rate, enhan_sig)

                if (batch+1) % 100 == 0:
                    tf.logging.info("Number of batch processed: %d." % (batch+1))

        except Exception, e:
            coord.request_stop(e)
        finally: