Python keras.models.model_from_config() Examples

The following are 4 code examples of keras.models.model_from_config(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module keras.models, or try the search function.
Example #1
Source File: util.py — from keras-rl (MIT License), 6 votes
def clone_model(model, custom_objects=None):
    """Return a copy of *model* rebuilt from its config, with its weights.

    The architecture is reconstructed with ``model_from_config`` from the
    model's class name and ``get_config()``. The weights are then copied
    over with ``get_weights()`` / ``set_weights()``.

    Args:
        model: The Keras model to copy.
        custom_objects: Optional mapping of names to custom classes/functions
            needed to deserialize the config. ``None`` (the default) means an
            empty mapping.

    Returns:
        The cloned model.
    """
    # Build a fresh dict per call. A ``{}`` default would be one shared
    # object across every call, so any mutation of it would leak between calls.
    if custom_objects is None:
        custom_objects = {}
    # Requires Keras 1.0.7 since get_config has breaking changes.
    config = {
        'class_name': model.__class__.__name__,
        'config': model.get_config(),
    }
    clone = model_from_config(config, custom_objects=custom_objects)
    clone.set_weights(model.get_weights())
    return clone
Example #2
Source File: base.py — from crema (BSD 2-Clause "Simplified" License), 5 votes
def _instantiate(self, rsc):
    """Load the pump, model, weights and version from resource directory *rsc*.

    Each file is located with ``resource_filename(__name__, ...)`` relative
    to *rsc*. Results are stored on ``self.pump``, ``self.model`` and
    ``self.version``.
    """
    def _path(fname):
        # Package-resource path of ``fname`` inside the ``rsc`` directory.
        return resource_filename(__name__, os.path.join(rsc, fname))

    # The pickled pump object.
    # NOTE(review): pickle.load can execute arbitrary code. This is presumably
    # acceptable only because these files ship with the package; confirm.
    with open(_path('pump.pkl'), 'rb') as pump_file:
        self.pump = pickle.load(pump_file)

    # The pickled model spec, rebuilt into a model. Every name listed in
    # ``layers.__all__`` is registered as a custom object.
    with open(_path('model_spec.pkl'), 'rb') as spec_file:
        spec = pickle.load(spec_file)
        custom = {name: vars(layers)[name] for name in layers.__all__}
        self.model = model_from_config(spec, custom_objects=custom)

    # The model weights, read from model.h5.
    self.model.load_weights(_path('model.h5'))

    # The version string, with surrounding whitespace removed.
    with open(_path('version.txt'), 'r') as version_file:
        self.version = version_file.read().strip()
Example #3
Source File: util.py — from openai_lab (MIT License), 5 votes
def clone_model(model, custom_objects=None):
    """Copy a Keras model by rebuilding it from its config and copying weights.

    Args:
        model: The Keras model to copy.
        custom_objects: Optional name-to-object mapping forwarded to
            ``model_from_config``. Any falsy value is replaced by ``{}``.

    Returns:
        A new model built from ``model``'s config and holding its weights.
    """
    # Imported inside the function, so keras is only imported when this runs.
    from keras.models import model_from_config
    objects = custom_objects if custom_objects else {}
    serialized = dict(
        class_name=model.__class__.__name__,
        config=model.get_config(),
    )
    clone = model_from_config(serialized, custom_objects=objects)
    clone.set_weights(model.get_weights())
    return clone


# clone a keras optimizer without file I/O 
Example #4
Source File: convertkeras.py — from keras_to_tensorflow (MIT License), 4 votes
def convert(prevmodel, export_path, freeze_graph_binary):
    """Export a saved Keras model as a TF graph + checkpoint, then freeze it.

    The steps are:

    1. Load the model at *prevmodel*.
    2. Rebuild it from its config with the Keras learning phase fixed to 0
       (test mode), and copy the original weights into the rebuilt model.
    3. Write the session graph to ``<export_path>_graph.pb`` and the
       variables to ``<export_path>.ckpt``.
    4. Run *freeze_graph_binary* through the shell to produce
       ``<export_path>.pb``.

    Args:
        prevmodel: Path of the saved Keras model, passed to ``load_model``.
        export_path: Path prefix for every output file.
        freeze_graph_binary: Command that runs TensorFlow's ``freeze_graph``
            tool. It is inserted into the shell command unquoted, so it may
            contain its own arguments or a leading ``~``.

    Returns:
        The value returned by ``os.system`` for the freeze_graph command
        (non-zero means the command failed).

    Note:
        This installs a new ``tf.Session`` as the Keras session and leaves
        it open. It also sets the Keras learning phase to 0 for all ops
        created afterwards.
    """
    import shlex  # local import: used to quote arguments for os.system

    # Open a TensorFlow session and tell Keras to use it.
    sess = tf.Session()
    K.set_session(sess)

    # From this document:
    # https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html
    # All new operations will be in test mode from now on.
    K.set_learning_phase(0)

    # Load the trained model; keep its config and weights for the rebuild.
    previous_model = load_model(prevmodel)
    previous_model.summary()
    config = previous_model.get_config()
    weights = previous_model.get_weights()

    # Rebuild the model so the learning phase is hard-coded to 0.
    # The class is chosen from the loaded model rather than by trying
    # Sequential and catching everything. A bare ``except`` would also
    # swallow KeyboardInterrupt and real errors, and a failed attempt can
    # leave stray ops in the graph.
    model_cls = Sequential if isinstance(previous_model, Sequential) else Model
    model = model_cls.from_config(config)
    model.set_weights(weights)

    print("Input name:")
    print(model.input.name)
    print("Output name:")
    print(model.output.name)
    # Drop the ":<index>" suffix of the tensor name to get the op name.
    output_name = model.output.name.split(':')[0]

    graph_file = export_path + "_graph.pb"
    ckpt_file = export_path + ".ckpt"
    saver = tf.train.Saver(sharded=True)
    tf.train.write_graph(sess.graph_def, '', graph_file)
    saver.save(sess, ckpt_file)

    # Example of the command that gets built:
    # ~/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph \
    #     --input_graph=./graph.pb --input_checkpoint=./model.ckpt \
    #     --output_node_names=add_72 --output_graph=frozen.pb
    # os.path.join(".", p) keeps the "./" prefix for relative paths but leaves
    # absolute paths intact. Plain "./" + "/abs" would make them relative.
    args = [
        "--input_graph=" + os.path.join(".", graph_file),
        "--input_checkpoint=" + os.path.join(".", ckpt_file),
        "--output_node_names=" + output_name,
        "--output_graph=" + os.path.join(".", export_path + ".pb"),
    ]
    # Each argument is shell-quoted so paths containing spaces or shell
    # metacharacters survive intact.
    command = freeze_graph_binary + " " + " ".join(shlex.quote(a) for a in args)
    print(command)
    return os.system(command)