Python copy.copy() Examples

The following are 30 code examples of copy.copy(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module copy, or try the search function.
Example #1
Source File: replay_memory.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def copy(self):
        """Return a copy of this replay memory that holds its most recent slots.

        The result is a shallow ``copy.copy`` of ``self`` whose ``states``,
        ``actions``, ``rewards`` and ``terminate_flags`` arrays are replaced by
        freshly allocated zero arrays of the same shapes.  Only the ``self.size``
        slots ending just before ``self.top`` are copied across; when
        ``self.top < self.size`` the range contains negative indices, which
        NumPy fancy indexing wraps to the end of the arrays.
        """
        # TODO Test the copy function
        clone = copy.copy(self)
        recent = numpy.arange(self.top - self.size, self.top)

        clone.states = numpy.zeros(self.states.shape, dtype=self.states.dtype)
        clone.actions = numpy.zeros(self.actions.shape, dtype=self.actions.dtype)
        clone.rewards = numpy.zeros(self.rewards.shape, dtype='float32')
        clone.terminate_flags = numpy.zeros(self.terminate_flags.shape, dtype='bool')

        clone.states[recent, ::] = self.states[recent]
        clone.actions[recent] = self.actions[recent]
        clone.rewards[recent] = self.rewards[recent]
        clone.terminate_flags[recent] = self.terminate_flags[recent]
        return clone
Example #2
Source File: lstm.py    From fine-lm with MIT License 6 votes vote down vote up
def lstm_seq2seq_internal_attention_bid_encoder(inputs, targets, hparams,
                                                train):
  """LSTM seq2seq model with attention, main step used for training.

  Encodes `inputs` with a bidirectional LSTM and decodes the right-shifted
  `targets` with an attention LSTM decoder.

  Args:
    inputs: embedded source tensor. It is passed through
      `common_layers.flatten4d3d`, so presumably 4-D; confirm with callers.
    targets: embedded target tensor.
    hparams: model hyperparameters. This function does not assign to it; the
      decoder receives a shallow copy whose `hidden_size` is doubled.
    train: whether the graph is being built for training.

  Returns:
    The decoder outputs with a new axis inserted at position 2.
  """
  with tf.variable_scope("lstm_seq2seq_attention_bid_encoder"):
    source_length = common_layers.length_from_embedding(inputs)
    flat_source = common_layers.flatten4d3d(inputs)
    encoder_outputs, encoder_state = lstm_bid_encoder(
        flat_source, source_length, hparams, train, "encoder")

    decoder_inputs = common_layers.shift_right(targets)
    # shift_right pads one position on the left; count it in the length.
    target_length = common_layers.length_from_embedding(decoder_inputs) + 1

    # NOTE(review): the doubling presumably matches the width of the
    # bidirectional encoder's state; confirm against lstm_bid_encoder.
    decoder_hparams = copy.copy(hparams)
    decoder_hparams.hidden_size = 2 * hparams.hidden_size

    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(decoder_inputs), decoder_hparams, train,
        "decoder", encoder_state, encoder_outputs,
        source_length, target_length)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #3
Source File: bbs.py    From cat-bbs with MIT License 6 votes vote down vote up
def add_border(self, val, img_shape=None):
        """Return a new Rectangle grown outward from this one by ``val``.

        Parameters
        ----------
        val : int, float or tuple of length 4
            * ``0`` (or ``0.0``): return ``self.copy()`` unchanged.
            * int: absolute margin added on every side.
            * float: margin as a fraction of the box width (horizontal sides)
              and height (vertical sides); coordinates are truncated with
              ``int``.
            * tuple: per-side margins applied to ``(y1, x2, y2, x1)`` in that
              order. These are top/right/bottom/left when y1 is the top edge.
              The entries must be all ints (absolute) or all floats
              (fractions); plain ``0`` entries may be mixed with floats.
        img_shape : sequence, optional
            If given, ``img_shape[0]`` / ``img_shape[1]`` are passed as
            height / width to ``Rectangle.fix_by_image_dimensions``.

        Raises
        ------
        TypeError
            If ``val`` or a tuple entry has an unsupported type. This is a
            subclass of ``Exception``, so callers catching ``Exception`` are
            unaffected.
        AssertionError
            If a tuple ``val`` does not have exactly 4 entries.
        """
        if val == 0:
            return self.copy()

        if isinstance(val, int):
            rect = Rectangle(x1=self.x1-val, x2=self.x2+val, y1=self.y1-val, y2=self.y2+val)
        elif isinstance(val, float):
            rect = Rectangle(x1=int(self.x1 - self.width*val), x2=int(self.x2 + self.width*val), y1=int(self.y1 - self.height*val), y2=int(self.y2 + self.height*val))
        elif isinstance(val, tuple):
            assert len(val) == 4, str(len(val))

            if all(isinstance(subval, int) for subval in val):
                rect = Rectangle(x1=self.x1-val[3], x2=self.x2+val[1], y1=self.y1-val[0], y2=self.y2+val[2])
            # "or subval == 0" is needed because otherwise e.g. (0.1, 0, 0.1, 0)
            # would be rejected (0 is an int, not a float).
            elif all(isinstance(subval, float) or subval == 0 for subval in val):
                rect = Rectangle(x1=int(self.x1 - self.width*val[3]), x2=int(self.x2 + self.width*val[1]), y1=int(self.y1 - self.height*val[0]), y2=int(self.y2 + self.height*val[2]))
            else:
                raise TypeError("Tuple of all ints or tuple of all floats expected, got %s" % (str([type(v) for v in val]),))
        else:
            raise TypeError("int or float or tuple of ints/floats expected, got %s" % (type(val),))

        if img_shape is not None:
            rect.fix_by_image_dimensions(height=img_shape[0], width=img_shape[1])

        return rect
Example #4
Source File: tf_util.py    From lirpg with MIT License 6 votes vote down vote up
def switch(condition, then_expression, else_expression):
    """Select one of two tensors based on a scalar value (int or bool).

    Both `then_expression` and `else_expression` should be symbolic tensors
    of the *same shape*.  The static shape of `then_expression` (captured
    before branching) is re-applied to the `tf.cond` result.

    # Arguments
        condition: scalar tensor.
        then_expression: TensorFlow operation.
        else_expression: TensorFlow operation.
    """
    static_shape = copy.copy(then_expression.get_shape())
    predicate = tf.cast(condition, 'bool')
    selected = tf.cond(predicate,
                       lambda: then_expression,
                       lambda: else_expression)
    selected.set_shape(static_shape)
    return selected

# ================================================================
# Extras
# ================================================================ 
Example #5
Source File: ddpg.py    From lirpg with MIT License 6 votes vote down vote up
def setup_param_noise(self, normalized_obs0):
        """Build the graph pieces used for parameter-space exploration noise.

        Two shallow copies of ``self.actor`` are made and renamed
        (``'param_noise_actor'`` and ``'adaptive_param_noise_actor'``).  The
        renaming presumably gives each copy its own variable scope; confirm
        in the actor class.  Both are applied to ``normalized_obs0`` and paired
        with update ops from ``get_perturbed_actor_updates``.  The adaptive copy
        also defines ``self.adaptive_policy_distance``, the RMS difference
        between ``self.actor_tf`` and its own output.
        """
        assert self.param_noise is not None

        # Configure perturbed actor.
        perturbed = copy(self.actor)
        perturbed.name = 'param_noise_actor'
        self.perturbed_actor_tf = perturbed(normalized_obs0)
        logger.info('setting up param noise')
        self.perturb_policy_ops = get_perturbed_actor_updates(
            self.actor, perturbed, self.param_noise_stddev)

        # Configure separate copy for stddev adoption.
        adaptive = copy(self.actor)
        adaptive.name = 'adaptive_param_noise_actor'
        adaptive_out = adaptive(normalized_obs0)
        self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(
            self.actor, adaptive, self.param_noise_stddev)
        gap = self.actor_tf - adaptive_out
        self.adaptive_policy_distance = tf.sqrt(tf.reduce_mean(tf.square(gap)))
Example #6
Source File: ddpg.py    From lirpg with MIT License 6 votes vote down vote up
def adapt_param_noise(self):
        """Adapt the parameter-noise scale and return the measured distance.

        Returns ``0.`` immediately when parameter noise is disabled.  Otherwise
        it does four things in order:

        1. Perturbs the adaptive actor copy with the current noise stddev.
        2. Evaluates ``self.adaptive_policy_distance`` on a freshly sampled
           batch.
        3. Averages that distance over all MPI ranks (allreduce SUM divided by
           the communicator size).
        4. Passes the mean to ``self.param_noise.adapt`` and returns it.
        """
        if self.param_noise is None:
            return 0.

        # Perturb a separate copy of the policy to adjust the scale for the
        # next "real" perturbation.
        batch = self.memory.sample(batch_size=self.batch_size)
        perturb_feed = {self.param_noise_stddev: self.param_noise.current_stddev}
        self.sess.run(self.perturb_adaptive_policy_ops, feed_dict=perturb_feed)

        distance_feed = {
            self.obs0: batch['obs0'],
            self.param_noise_stddev: self.param_noise.current_stddev,
        }
        distance = self.sess.run(self.adaptive_policy_distance, feed_dict=distance_feed)

        comm = MPI.COMM_WORLD
        total = comm.allreduce(distance, op=MPI.SUM)
        mean_distance = total / comm.Get_size()
        self.param_noise.adapt(mean_distance)
        return mean_distance
Example #7
Source File: __main__.py    From vergeml with MIT License 6 votes vote down vote up
def _forgive_wrong_option_order(argv):
    """Reorder ``argv`` so that vergeml's own ``--options`` come first.

    An option's name is the token with its leading dashes and any ``=value``
    stripped.  Options whose name is in ``_VERGEML_OPTION_NAMES`` move to the
    front in their original relative order; everything else follows in its
    original order.

    An option written without an inline value (``--opt`` or ``--opt=``) takes
    the next argument as its value and moves together with it.  If such an
    option is the last argument, it is left in the trailing group.

    ``argv`` itself is not modified.
    """
    leading = []
    trailing = []
    pending = copy(argv)

    while pending:
        token = pending.pop(0)

        if not token.startswith("--"):
            trailing.append(token)
            continue

        name = token.lstrip("-").split("=")[0]
        target = leading if name in _VERGEML_OPTION_NAMES else trailing

        if "=" in token and not token.endswith("="):
            # Inline value (--opt=value): move the token on its own.
            target.append(token)
        elif pending:
            # Value is the next argument: move both together.
            target.append(token)
            target.append(pending.pop(0))
        else:
            # Dangling option with nothing left to pair with: give up.
            trailing.append(token)

    return leading + trailing
Example #8
Source File: lex.py    From SublimeKSP with GNU General Public License v3.0 6 votes vote down vote up
def get_caller_module_dict(levels):
    """Return a merged copy of the globals and locals of an enclosing frame.

    ``levels`` counts frames outward from this function's own frame: 1 is
    the direct caller, 2 the caller's caller, and so on.  The result starts
    as a copy of that frame's globals; when its locals differ from its
    globals (e.g. inside a function), the locals are merged on top.
    """
    # Obtain the current frame from the traceback of a caught exception.
    try:
        raise RuntimeError
    except RuntimeError:
        frame = sys.exc_info()[2].tb_frame
        for _ in range(levels):
            frame = frame.f_back
        namespace = frame.f_globals.copy()
        if frame.f_globals != frame.f_locals:
            namespace.update(frame.f_locals)

        return namespace

# -----------------------------------------------------------------------------
# _funcs_to_names()
#
# Given a list of regular expression functions, this converts it to a list
# suitable for output to a table file
# ----------------------------------------------------------------------------- 
Example #9
Source File: composite.py    From indras_net with GNU General Public License v3.0 6 votes vote down vote up
def __add__(self, other):
        """
        Return the union of this composite and ``other`` as a new group.

        ``None`` yields ``self`` unchanged.  If ``other`` is a composite, its
        members are merged into a shallow copy of ``self.members``; otherwise
        ``other`` itself is added to that copy under ``other.name``.  The new
        group is named ``"<self.name>+<other.name>"`` and is registered with
        both operands via ``add_group`` (self first, then other).
        """
        if other is None:
            return self

        merged = copy(self.members)
        if not is_composite(other):
            merged[other.name] = other
        else:
            merged.update(other.members)

        union = grp_from_nm_dict(self.name + "+" + other.name, merged)
        for participant in (self, other):
            participant.add_group(union)
        return union
Example #10
Source File: image_transformer_2d.py    From fine-lm with MIT License 6 votes vote down vote up
def body(self, features):
    """Decode image targets with the 2-D image-transformer decoder stack.

    Works on a shallow copy of `self._hparams`.  It emits a `targets` image
    summary unless the variable scope is reusing variables or the mode is
    INFER.  Unless `hparams.unconditional` is set, it adds the reshaped
    `features["inputs"]` (the class label) to the decoder input.  It returns
    the output of `cia.create_output`.
    """
    hparams = copy.copy(self._hparams)
    inputs = features["inputs"]
    targets = features["targets"]
    targets_shape = common_layers.shape_list(targets)
    reusing_or_infer = (tf.get_variable_scope().reuse or
                        hparams.mode == tf.contrib.learn.ModeKeys.INFER)
    if not reusing_or_infer:
      tf.summary.image("targets", targets, max_outputs=1)

    decoder_input, rows, cols = cia.prepare_decoder(targets, hparams)
    if not hparams.unconditional:
      # Add class label to decoder input.
      label = tf.reshape(inputs,
                         [targets_shape[0], 1, 1, hparams.hidden_size])
      decoder_input += label

    decoder_output = cia.transformer_decoder_layers(
        decoder_input, None, hparams.num_decoder_layers, hparams,
        attention_type=hparams.dec_attention_type, name="decoder")
    return cia.create_output(decoder_output, rows, cols, targets, hparams)
Example #11
Source File: bbs.py    From cat-bbs with MIT License 6 votes vote down vote up
def draw_on_image(self, img, color=[0, 255, 0], alpha=1.0, copy=True, from_img=None):
        """Draw every rectangle of this collection onto ``img`` and return it.

        Parameters
        ----------
        img : numpy array
            Image to draw on.  With ``copy=True`` a copy is drawn on instead.
        color, alpha :
            Forwarded unchanged to each rectangle's ``draw_on_image``.  When
            ``alpha != 1.0`` and the image is not float32, drawing happens on a
            float32 version that is converted back to the original dtype at
            the end.  ``astype`` then returns a new array, so with
            ``copy=False`` the caller's array is only drawn on in place when no
            conversion occurs.
        from_img : optional
            If given, each rectangle is first mapped with
            ``rect.resize(from_img, img)`` before drawing.
        """
        canvas = np.copy(img) if copy else img

        original_dtype = canvas.dtype
        if alpha != 1.0 and canvas.dtype != np.float32:
            canvas = canvas.astype(np.float32, copy=False)

        for rect in self:
            target = rect if from_img is None else rect.resize(from_img, canvas)
            target.draw_on_image(canvas, color=color, alpha=alpha, copy=False)

        if canvas.dtype != original_dtype:
            canvas = canvas.astype(original_dtype, copy=False)

        return canvas
Example #12
Source File: GameOfLife.py    From BiblioPixelAnimations with MIT License 6 votes vote down vote up
def turn(self):
        """Advance the board by one generation and record the previous board.

        The next board is built on a deep copy of ``self.table``.  Neighbour
        counts come from ``self.liveNeighbours``, which sees the unmodified
        current board throughout.
        A dead cell with exactly 3 neighbours becomes 1.  A live cell with
        fewer than 2 or more than 3 neighbours becomes 0.  The previous board
        is appended to ``self._oldStates``, whose oldest entry is dropped once
        it holds more than 3.
        """
        nxt = copy.deepcopy(self.table)
        for row in range(self.height):
            for col in range(self.width):
                count = self.liveNeighbours(row, col)
                alive = self.table[row][col] != 0
                if not alive:
                    if count == 3:
                        nxt[row][col] = 1
                elif count < 2 or count > 3:
                    nxt[row][col] = 0

        self._oldStates.append(self.table)
        if len(self._oldStates) > 3:
            self._oldStates.popleft()

        self.table = nxt
Example #13
Source File: zmirror.py    From zmirror with MIT License 6 votes vote down vote up
def generate_our_response():
    """
    Build the response we send back to our client.

    Copies the remote response, then adds headers:

    * ``X-Header-Req-Time``, when the header request time is at least 1e-5.
    * ``X-Body-Req-Time`` and ``X-Compute-Time``, for non-streamed responses
      that have a start time.
    * ``X-Powered-By``, always.

    For non-streamed responses, it also dumps a "traffic" snapshot when
    developer traffic dumping is enabled.
    :rtype: Response
    """
    # copy and parse remote response
    resp = copy_response(is_streamed=parse.streamed_our_response)

    header_req_time = parse.time["req_time_header"]
    if header_req_time >= 0.00001:
        parse.set_extra_resp_header('X-Header-Req-Time', "%.4f" % header_req_time)

    if not parse.streamed_our_response and parse.time.get("start_time") is not None:
        # remote request time should be excluded when calculating total time
        parse.set_extra_resp_header('X-Body-Req-Time', "%.4f" % parse.time["req_time_body"])
        compute_time = process_time() - parse.time["start_time"]
        parse.set_extra_resp_header('X-Compute-Time', "%.4f" % compute_time)

    parse.set_extra_resp_header('X-Powered-By', 'zmirror/%s' % CONSTS.__VERSION__)

    if developer_dump_all_traffics and not parse.streamed_our_response:
        dump_zmirror_snapshot("traffic")

    return resp
Example #14
Source File: tf_util.py    From cs294-112_hws with MIT License 6 votes vote down vote up
def switch(condition, then_expression, else_expression):
    """Select one of two tensors based on a scalar value (int or bool).

    Both `then_expression` and `else_expression` should be symbolic tensors
    of the *same shape*.  The static shape of `then_expression` (captured
    before branching) is re-applied to the `tf.cond` result.

    # Arguments
        condition: scalar tensor.
        then_expression: TensorFlow operation.
        else_expression: TensorFlow operation.
    """
    static_shape = copy.copy(then_expression.get_shape())
    predicate = tf.cast(condition, 'bool')
    selected = tf.cond(predicate,
                       lambda: then_expression,
                       lambda: else_expression)
    selected.set_shape(static_shape)
    return selected

# Extras
# ---------------------------------------- 
Example #15
Source File: block.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
                   force_reinit=False):
        """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children.
        Equivalent to ``block.collect_params().initialize(...)``

        Note: the default ``init`` instance is created once, when this function
        is defined, and is shared by every call that omits ``init``.

        Parameters
        ----------
        init : Initializer
            Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``.
            Otherwise, :py:meth:`Parameter.init` takes precedence.
        ctx : Context or list of Context
            Keeps a copy of Parameters on one or many context(s).
        verbose : bool, default False
            Whether to verbosely print out details on initialization.
        force_reinit : bool, default False
            Whether to force re-initialization if parameter is already initialized.
        """
        params = self.collect_params()
        params.initialize(init, ctx, verbose, force_reinit)
Example #16
Source File: missing_value.py    From goodtables-py with MIT License 6 votes vote down vote up
def missing_value(cells):
    """
    missing-value: a row has fewer cells than the header.

    Every cell with no ``'value'`` key that is not marked ``'is-virtual'``
    produces a ``missing-value`` Error and is removed from ``cells``, which
    is mutated in place.  Returns the list of errors in cell order.
    """
    found = []

    # Iterate over a snapshot because matching cells are removed as we go.
    for cell in copy(cells):
        # A missing 'value' key is different from a falsy value (None, False,
        # '', ...), so test for the key rather than using cell.get('value').
        has_value = 'value' in cell
        if has_value or cell.get('is-virtual'):
            continue

        found.append(Error('missing-value', cell))
        cells.remove(cell)

    return found
Example #17
Source File: non_matching_header.py    From goodtables-py with MIT License 6 votes vote down vote up
def _check_without_ordering(cells):
    """Report header cells whose header text differs from their field name.

    Cells without a ``'field'`` or with a ``None`` header are ignored.  A
    differing header yields a ``non-matching-header`` Error.  The cell is
    also removed from ``cells`` (mutated in place) when even the slugified
    forms of header and field name differ.  Returns the list of errors.
    """
    errors = []

    # Iterate over a snapshot because cells may be removed during the loop.
    for cell in copy(cells):
        field = cell.get('field')
        if field is None:
            continue

        header = cell.get('header')
        if header is None or header == field.name:
            continue

        message_substitutions = {
            'field_name': '"{}"'.format(field.name),
            'header': '"{}"'.format(header),
        }
        errors.append(Error(
            'non-matching-header',
            cell,
            message_substitutions=message_substitutions,
        ))

        # Keep cells whose header only differs cosmetically from the field.
        if _slugify(header) != _slugify(field.name):
            cells.remove(cell)

    return errors
Example #18
Source File: filemap.py    From neuropythy with GNU Affero General Public License v3.0 6 votes vote down vote up
def _check_tarball(self, *path_parts):
        """Resolve ``path_parts`` against any tarball on this object's path.

        Returns a ``(path_object, relative_path)`` pair, checked in this order:

        * If ``base_path`` is itself a tarball, returns
          ``(self._cache_tarball(''), joined parts)``.
        * If ``base_path`` points inside a tarball, a shallow copy whose
          ``base_path`` is the tarball file is cached instead.  The part of
          ``base_path`` inside the tarball is prefixed to the relative path.
        * Otherwise, if base path plus parts crosses a tarball, that tarball
          is cached and the path inside it is returned.
        * With no tarball anywhere, returns ``(self, joined parts)``.
        """
        rpath = self.join(*path_parts)

        if self.base_path is not None and self.base_path != '':
            tar_file, tar_inner = split_tarball_path(self.base_path)
            if tar_file is not None:
                if tar_inner == '':
                    # The base path is the tarball: just cache it.
                    return (self._cache_tarball(''), rpath)
                # Base path lies inside a tarball: defer to a copy of ourselves
                # rooted at the tarball, with the inner part moved into rpath.
                proxy = copy.copy(self)
                object.__setattr__(proxy, 'base_path', tar_file)
                inner_rpath = self.join(tar_inner, rpath)
                return (proxy._cache_tarball(''), inner_rpath)

        # Next check the full path (base path plus relative path).
        full_path = self.join('' if self.base_path is None else self.base_path, rpath)
        tar_file, tar_inner = split_tarball_path(full_path)
        if tar_file is None:
            # No tarball on the path: we are already the right object.
            return (self, rpath)
        return (self._cache_tarball(tar_file), tar_inner)
Example #19
Source File: utils.py    From pypika with Apache License 2.0 6 votes vote down vote up
def builder(func: Callable) -> Callable:
    """
    Decorator for wrapper "builder" functions.  These are functions on the Query class or other classes used for
    building queries which mutate the query and return self.  To make the build functions immutable, this decorator
    runs the inner function on a *shallow* copy of the instance (``copy.copy``, so any ``__copy__`` defined by the
    class is honoured).  If the instance has a falsy ``immutable`` attribute, the inner function runs on the instance
    itself instead.  The decorator returns the inner function's return value, or the (possibly copied) instance when
    the inner function returns ``None``, so the inner function does not need to return self.

    ``functools.wraps`` keeps the decorated method's name, docstring and other metadata.
    """
    import copy
    import functools

    @functools.wraps(func)
    def _copy(self, *args, **kwargs):
        self_copy = copy.copy(self) if getattr(self, "immutable", True) else self
        result = func(self_copy, *args, **kwargs)

        # Return self if the inner function returns None.  This way the inner function can return something
        # different (for example when creating joins, a different builder is returned).
        if result is None:
            return self_copy

        return result

    return _copy
Example #20
Source File: fake_cloud_client_test.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_deep_copy(self):
  """A deep copy of an entity shares no mutable state with the original."""
  client = fake_cloud_client
  original = client.make_entity(('abc', '1'))
  original['k1'] = ['v1']
  expected_key = client.FakeDatastoreKey('abc', '1')
  self.assertEqual(original.key, expected_key)
  self.assertEqual(dict(original), {'k1': ['v1']})

  duplicate = copy.deepcopy(original)
  duplicate['k1'].append('v2')
  duplicate['k3'] = 'v3'

  self.assertIsInstance(duplicate, client.FakeDatastoreEntity)
  # deepcopy duplicated the 'k1' list, so the original still holds ['v1'].
  self.assertEqual(original.key, expected_key)
  self.assertEqual(dict(original), {'k1': ['v1']})
  self.assertEqual(duplicate.key, expected_key)
  self.assertEqual(dict(duplicate), {'k1': ['v1', 'v2'], 'k3': 'v3'})
Example #21
Source File: fake_cloud_client_test.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_copy(self):
  """A shallow copy shares nested values but not top-level keys."""
  client = fake_cloud_client
  original = client.make_entity(('abc', '1'))
  original['k1'] = ['v1']
  expected_key = client.FakeDatastoreKey('abc', '1')
  self.assertEqual(original.key, expected_key)
  self.assertEqual(dict(original), {'k1': ['v1']})

  duplicate = copy.copy(original)
  duplicate['k1'].append('v2')
  duplicate['k3'] = 'v3'

  self.assertIsInstance(duplicate, client.FakeDatastoreEntity)
  # The 'k1' list is shared, so the append shows up in the original too,
  # while the new 'k3' key exists only in the copy.
  self.assertEqual(original.key, expected_key)
  self.assertEqual(dict(original), {'k1': ['v1', 'v2']})
  self.assertEqual(duplicate.key, expected_key)
  self.assertEqual(dict(duplicate), {'k1': ['v1', 'v2'], 'k3': 'v3'})
Example #22
Source File: queries.py    From pypika with Apache License 2.0 6 votes vote down vote up
def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "JoinUsing":
        """
        Replace every occurrence of ``current_table`` with ``new_table`` in this join.

        Useful when reusing fields across queries.  The join's ``item`` is
        swapped when it equals ``current_table``.  Each entry of ``fields`` is
        replaced by the result of its own ``replace_table`` call.

        NOTE(review): the body assigns to ``self`` and has no ``return``
        statement.  The documented "copy" return value presumably comes from
        a ``@builder``-style decorator on the original definition, which is
        not visible in this excerpt; confirm before relying on it.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the join with the tables replaced.
        """
        is_target = self.item == current_table
        self.item = new_table if is_target else self.item
        self.fields = [
            f.replace_table(current_table, new_table) for f in self.fields
        ]
Example #23
Source File: GameOfLife.py    From BiblioPixelAnimations with MIT License 6 votes vote down vote up
def create_time_table(self, t):
        """Render the time ``t`` as a table of lit (1) / unlit (0) pixels.

        ``t`` is a timestamp converted with ``time.localtime`` and formatted as
        ``HH:MM``.  It uses a 24-hour clock when ``self.mil_time`` is truthy.
        Otherwise it uses a 12-hour clock in which midnight and noon show as 12.

        The text is drawn in red, centred on the layout, and read back pixel
        by pixel into a ``self.height`` x ``self.width`` list of lists.  The
        layout's previous colours are restored before returning.
        """
        now = time.localtime(t)
        hour = now.tm_hour
        if not self.mil_time:
            # A 12-hour clock shows 12, not 00, for midnight and noon.
            hour = hour % 12 or 12
        val = str(hour).zfill(2) + ":" + str(now.tm_min).zfill(2)
        w, h = font.str_dim(val, font=self.font_name,
                            font_scale=self.scale, final_sep=False)
        x = (self.width - w) // 2
        y = (self.height - h) // 2
        # Snapshot the current colours so they can be restored after sampling.
        old_buf = copy.copy(self.layout.colors)
        self.layout.all_off()
        self.layout.drawText(val, x, y, color=COLORS.Red,
                             font=self.font_name, font_scale=self.scale)
        table = [[int(any(self.layout.get(col, row))) for col in range(self.width)]
                 for row in range(self.height)]
        self.layout.setBuffer(old_buf)
        return table
Example #24
Source File: test_symbol.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def test_symbol_copy():
    """Deep and shallow copies of a symbol serialize identically to it."""
    original = mx.symbol.Variable('data')
    deep = copy.deepcopy(original)
    shallow = copy.copy(original)
    expected = original.tojson()
    for duplicate in (deep, shallow):
        assert expected == duplicate.tojson()
Example #25
Source File: lstm.py    From fine-lm with MIT License 5 votes vote down vote up
def lstm_seq2seq_internal_bid_encoder(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model with bidirectional encoder.

  When `inputs` is None, no encoder is built and the decoder starts from a
  `None` initial state.  Otherwise, the bidirectional encoder's final state
  seeds the decoder.  The decoder gets a shallow copy of `hparams` with
  `hidden_size` doubled.

  Returns:
    The decoder outputs with a new axis inserted at position 2.
  """
  with tf.variable_scope("lstm_seq2seq_bid_encoder"):
    final_encoder_state = None
    if inputs is not None:
      source_length = common_layers.length_from_embedding(inputs)
      _, final_encoder_state = lstm_bid_encoder(
          common_layers.flatten4d3d(inputs), source_length, hparams, train,
          "encoder")

    decoder_inputs = common_layers.shift_right(targets)
    # shift_right pads one position on the left; count it in the length.
    target_length = common_layers.length_from_embedding(decoder_inputs) + 1
    decoder_hparams = copy.copy(hparams)
    decoder_hparams.hidden_size = 2 * hparams.hidden_size
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(decoder_inputs),
        target_length,
        decoder_hparams,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #26
Source File: slice.py    From tfont with Apache License 2.0 5 votes vote down vote up
def makePath(endSegment, segmentsMap, path=None, targetSegment=None):
    """Append the points of a chain of segments to ``path`` and return it.

    Walks the iterator stored in ``segmentsMap`` under ``targetSegment``
    (default: ``endSegment``).  That entry is popped, so ``segmentsMap`` is
    mutated, and a missing entry raises ``KeyError``.  For each segment in
    the walk:

    * If it is ``endSegment``, its points are added and the walk stops.
    * If it still has an entry in ``segmentsMap``, its points are added and
      the walk continues by recursing into that entry, appending to the
      same ``path``.

    If the iterator runs out before either case, the path built so far is
    returned.  Points are appended directly to ``path._points``, and every
    appended point (copied or not) gets ``_parent`` set to ``path``.
    A new ``Path()`` is created when ``path`` is None.
    """
    if path is None:
        path = Path()
    if targetSegment is None:
        targetSegment = endSegment
    iterator = segmentsMap.pop(targetSegment)
    points = path._points
    # Start point: a copy of the first item's on-curve point, forced to a
    # non-smooth "line" point owned by this path.
    point = copy(next(iterator).onCurve)
    point.smooth = False
    point.type = "line"
    point._parent = path
    points.append(point)
    for segment in iterator:
        # Membership is checked against the map as it stands now, after any
        # entries popped so far (including by recursive calls).
        isJump = segment in segmentsMap
        isLast = segment is endSegment
        for point in segment.penPoints:
            # original segment will be trashed so we only need to
            # copy the overlapping section
            # NOTE(review): `and` binds tighter than `or`, so this evaluates as
            # `isJump or (isLast and point.type is not None)`. Confirm that is
            # intended rather than `(isJump or isLast) and point.type is not None`.
            if isJump or isLast and point.type is not None:
                point = copy(point)
                point.smooth = False
            # Points that were not copied are re-parented in place.
            point._parent = path
            points.append(point)
        if isLast:
            break
        elif isJump:
            # Continue from this segment's own iterator; the rest of the
            # current iterator is abandoned.
            makePath(endSegment, segmentsMap, path, segment)
            break
    return path
Example #27
Source File: bleu_scorer.py    From deep-summarization with MIT License 5 votes vote down vote up
def copy(self):
        """
        Return a new BleuScorer with the same ``n`` and copied references.

        ``ctest`` and ``crefs`` are shallow copies (``copy.copy``) of this
        scorer's values; the cached ``_score`` of the new scorer is reset to
        ``None``.
        :return: the new scorer
        """
        duplicate = BleuScorer(n=self.n)
        duplicate.ctest, duplicate.crefs = copy.copy(self.ctest), copy.copy(self.crefs)
        duplicate._score = None
        return duplicate
Example #28
Source File: base.py    From Paradrop with Apache License 2.0 5 votes vote down vote up
def copy(self):
        """
        Make a copy of the config object.

        The copy is a fresh instance of the same class.  It has the same
        ``source``, ``name`` and ``comment``; ``.copy()``-ed ``parents`` and
        ``dependents``; and a shallow copy of each option value.  ``setup()``
        is called on the copy before it is returned.
        """
        clone = self.__class__()

        for attr in ('source', 'name', 'comment'):
            setattr(clone, attr, getattr(self, attr))

        clone.parents = self.parents.copy()
        clone.dependents = self.dependents.copy()

        # copy.copy handles both str- and list-typed option values; list values
        # only hold strings, so a shallow copy is enough.
        for option in self.options:
            setattr(clone, option.name, copy.copy(getattr(self, option.name)))

        # Setup must run only after all option values are filled in.
        clone.setup()

        return clone
Example #29
Source File: model_rl_experiment.py    From fine-lm with MIT License 5 votes vote down vote up
def train_agent(problem_name, agent_model_dir,
                event_dir, world_model_dir, epoch_data_dir, hparams, epoch=0):
  """Train the PPO agent in the simulated environment.

  PPO hparams are built from `hparams.ppo_params`.  Any `ppo_<name>`
  overrides present in `hparams` are then applied, and world-model, model
  and environment settings are attached.  Training runs with the world-model
  flags temporarily set.  The problem's `environment_spec` is shallow-copied
  before its attributes are assigned.
  """
  gym_problem = registry.problem(problem_name)
  ppo_hparams = trainer_lib.create_hparams(hparams.ppo_params)

  # Forward "ppo_<name>" overrides from hparams onto ppo_hparams.
  for name in ("epochs_num", "epoch_length", "learning_rate", "num_agents",
               "optimization_epochs"):
    override = "ppo_" + name
    if override in hparams:
      ppo_hparams.set_hparam(name, hparams.get(override))

  ppo_hparams.save_models_every_epochs = hparams.ppo_epochs_num
  ppo_hparams.world_model_dir = world_model_dir
  ppo_hparams.add_hparam("force_beginning_resets", True)

  # Adding model hparams for model specific adjustments
  ppo_hparams.add_hparam(
      "model_hparams",
      trainer_lib.create_hparams(hparams.generative_model_params))

  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts
  environment_spec.intrinsic_reward_scale = hparams.intrinsic_reward_scale
  ppo_hparams.add_hparam("environment_spec", environment_spec)

  world_model_flags = {
      "problem": problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }
  with temporary_flags(world_model_flags):
    rl_trainer_lib.train(ppo_hparams, event_dir, agent_model_dir, epoch=epoch)
Example #30
Source File: t2t_model.py    From fine-lm with MIT License 5 votes vote down vote up
def set_mode(self, mode):
  """Set hparams with the given mode.

  Takes a shallow copy of `self._original_hparams` and adds a `mode` hparam
  to it.  For any mode other than `tf.estimator.ModeKeys.TRAIN`, every
  hparam whose name ends in "dropout" is set to 0.0.  The copy then becomes
  `self._hparams`.
  """
  log_info("Setting T2TModel mode to '%s'", mode)
  hparams = copy.copy(self._original_hparams)
  hparams.add_hparam("mode", mode)
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  if not is_training:
    # When not in training mode, set all forms of dropout to zero.
    for name in hparams.values():
      if not name.endswith("dropout"):
        continue
      log_info("Setting hparams.%s to 0.0", name)
      setattr(hparams, name, 0.0)
  self._hparams = hparams