Python math.inf Examples

The following are 30 code examples showing how to use math.inf, the float constant for positive infinity (it is a constant, not a function). These examples are extracted from open-source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module math, or try the search function.

Example 1
Project: interpret-text   Author: interpretml   File: text_explainer_utils.py    License: MIT License 6 votes vote down vote up
def _find_golden_doc(function, evaluation_examples):
    """Return the example whose prediction has the largest maximum value.

    Each example is scored on its own by calling ``function([example])``.
    A 2-D result is reduced to its first row. The example's score is the
    largest entry, which is the highest class probability for classifiers
    or the prediction itself for regressors.

    :param function: Callable taking a one-element list of examples and
        returning an array-like with a ``shape`` attribute.
    :param evaluation_examples: Indexable sequence of examples.
    :return: The best-scoring example. If no score beats ``-math.inf``,
        the index stays -1 and the last example is returned.
    """
    best_score = -math.inf
    best_position = -1
    for position, example in enumerate(evaluation_examples):
        scores = function([example])
        if len(scores.shape) == 2:
            # Batched output: keep the single row for this example.
            scores = scores[0]
        # TODO: Change this to calculate multiple pred_max for each class prediction
        top = max(scores)
        if top > best_score:
            best_score, best_position = top, position
    return evaluation_examples[best_position]
Example 2
Project: Gemini   Author: anfederico   File: exchange.py    License: GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, no, entry_price, shares, exit_price=math.inf, stop_loss=0):
        """Open a long position.

        Passing ``False`` for ``exit_price`` or ``stop_loss`` has the same
        effect as omitting it: the default (``math.inf`` / ``0``) is used.

        :param no: A unique position id number
        :type no: float
        :param entry_price: Entry price at which shares are longed
        :type entry_price: float
        :param shares: Number of shares to long
        :type shares: float
        :param exit_price: Price at which to take profit
        :type exit_price: float
        :param stop_loss: Price at which to cut losses
        :type stop_loss: float

        :return: A long position
        :rtype: long_position
        """
        # ``False`` means "not given": fall back to the signature defaults.
        exit_price = math.inf if exit_price is False else exit_price
        stop_loss = 0 if stop_loss is False else stop_loss
        super().__init__(no, entry_price, shares, exit_price, stop_loss)
        self.type = 'long'
Example 3
Project: Gemini   Author: anfederico   File: exchange.py    License: GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, no, entry_price, shares, exit_price=0, stop_loss=math.inf):
        """Open a short position.

        Passing ``False`` for ``exit_price`` or ``stop_loss`` has the same
        effect as omitting it: the default (``0`` / ``math.inf``) is used.

        :param no: A unique position id number
        :type no: int
        :param entry_price: Entry price at which shares are shorted
        :type entry_price: float
        :param shares: Number of shares to short
        :type shares: float
        :param exit_price: Price at which to take profit
        :type exit_price: float
        :param stop_loss: Price at which to cut losses
        :type stop_loss: float

        :return: A short position
        :rtype: short_position
        """
        # ``False`` means "not given": fall back to the signature defaults.
        exit_price = 0 if exit_price is False else exit_price
        stop_loss = math.inf if stop_loss is False else stop_loss
        super().__init__(no, entry_price, shares, exit_price, stop_loss)
        self.type = 'short'
Example 4
Project: resolwe   Author: genialis   File: listener.py    License: Apache License 2.0 6 votes vote down vote up
def check_critical_load(self):
        """Check for critical load and log an error if necessary.

        If the 1-minute load average is above 1, an error is logged. If it is
        above 0.8, a warning is logged. Each level is throttled to one message
        per 30 seconds. Otherwise the throttling state is reset.
        """
        load = self.load_avg.intervals["1m"].value
        if load > 1:
            level = 1
        elif load > 0.8:
            level = 0.8
        else:
            # Load is normal again: reset so the next spike is logged at once.
            self.last_load_log = -math.inf
            self.last_load_level = 0
            return

        # Throttle: skip if this level was already reported within 30 s.
        if self.last_load_level == level and time.time() - self.last_load_log < 30:
            return
        self.last_load_log = time.time()
        self.last_load_level = level
        if level == 1:
            logger.error(
                "Listener load limit exceeded, the system can't handle this!",
                extra=self._make_stats(),
            )
        else:
            logger.warning(
                "Listener load approaching critical!", extra=self._make_stats()
            )
Example 5
Project: video_captioning_rl   Author: ramakanth-pasunuru   File: esim.py    License: MIT License 6 votes vote down vote up
def similarity(self, s1, l1, s2, l2):
        """Build the masked similarity matrix between two padded sequences.

        :param s1: [B, t1, D]
        :param l1: [B] valid lengths of ``s1``
        :param s2: [B, t2, D]
        :param l2: [B] valid lengths of ``s2``
        :return: [B, t1, t2] matrix ``S = s1 @ s2^T`` (the similarity matrix
            from the BiDAF paper). Entries outside the valid
            ``[:l1[i], :l2[i]]`` window are set to ``-inf``.
        """
        # [B, t1, D] * [B, D, t2] -> [B, t1, t2]
        S = torch.bmm(s1, s2.transpose(1, 2))

        # True marks padded positions. Modern PyTorch requires a bool mask for
        # masked_fill_ (uint8 masks are rejected), and Variable is unnecessary.
        s_mask = torch.ones_like(S, dtype=torch.bool)
        for i, (l_1, l_2) in enumerate(zip(l1, l2)):
            s_mask[i, :l_1, :l_2] = False

        # Fill through .data, as before, so the fill bypasses autograd.
        S.data.masked_fill_(s_mask, -math.inf)
        return S
Example 6
Project: translate   Author: pytorch   File: beam_decode.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _add_to_end_states(
        self, end_states: List[Tensor], min_score: float, state: Tensor, min_index: int
    ) -> Tuple[List[Tensor], float, int]:
        """Keep at most ``self.nbest`` end states with the highest scores.

        ``state[0]`` is the state's score. While fewer than ``nbest`` states
        are held, ``state`` is appended. Otherwise it replaces the current
        worst state, but only if it scores higher. ``min_score`` and
        ``min_index`` track the worst held state (ties resolve to the last
        index). They are returned updated, together with the mutated
        ``end_states`` list.
        """
        if len(end_states) < self.nbest:
            end_states.append(state)
            new_score = float(state[0])
            if new_score <= min_score:
                min_score, min_index = new_score, len(end_states) - 1
            return end_states, min_score, min_index

        if bool(state[0] > min_score):
            # Evict the worst hypothesis, then rescan for the new worst one.
            end_states[min_index] = state
            min_index, min_score = -1, float("inf")
            for position, held in enumerate(end_states):
                held_score = float(held[0])
                if held_score <= min_score:
                    min_index, min_score = position, held_score
        return end_states, min_score, min_index
Example 7
def transform(self, X: dt.Frame):
        """Score the cosine similarity between the first two text columns of ``X``.

        ``None`` and +/-inf cells are first replaced with ``self._repl_val``.
        Each pair of texts is lower-cased and embedded with a pooled document
        embedding built from ``self.embedding_name``. The cosine similarity of
        the two embeddings is recorded for each pair.

        :return: numpy array with one score per row. Rows whose embedding
            fails get the sentinel ``-99``.
        """
        X.replace([None, math.inf, -math.inf], self._repl_val)
        from flair.embeddings import WordEmbeddings, BertEmbeddings, DocumentPoolEmbeddings, Sentence
        if self.embedding_name in ["glove", "en"]:
            self.embedding = WordEmbeddings(self.embedding_name)
        elif self.embedding_name in ["bert"]:
            self.embedding = BertEmbeddings()
        self.doc_embedding = DocumentPoolEmbeddings([self.embedding])
        X = X.to_pandas()
        output = []
        # Both columns come from the same frame, so zip pairs rows one-to-one.
        for text1, text2 in zip(X.iloc[:, 0].values, X.iloc[:, 1].values):
            try:
                sentence1 = Sentence(str(text1).lower())
                self.doc_embedding.embed(sentence1)
                sentence2 = Sentence(str(text2).lower())
                self.doc_embedding.embed(sentence2)
                score = cosine_similarity(sentence1.get_embedding().reshape(1, -1),
                                          sentence2.get_embedding().reshape(1, -1))[0, 0]
            except Exception:
                # Best effort per row. Unlike the previous bare ``except:``,
                # this no longer swallows KeyboardInterrupt/SystemExit.
                score = -99
            output.append(score)
        return np.array(output)
Example 8
Project: SFA_Python   Author: sharford5   File: SFASupervised.py    License: GNU General Public License v3.0 6 votes vote down vote up
def fitTransformed(self, samples, wordLength, symbols, normMean):
        """Fit the SFA transform and keep the best coefficients up to ``wordLength``.

        Sets ``self.bestValues`` to the selected coefficient indices (``0``
        where the entry's score is ``-math.inf``). Sets ``self.maxWordLength``
        (and ``self.sfa.maxWordLength``) to one past the largest selected
        index, rounded up to an even number.

        :return: Result of ``self.sfa.transform`` on the fitted signal.
        """
        series_length = len(samples[0].data)
        signal = self.sfa.fitTransformDouble(samples, series_length, symbols, normMean)

        ranked = self.calcBestCoefficients(samples, signal)
        kept = min(len(ranked), wordLength)
        self.bestValues = [0] * kept

        longest = 0
        for slot in range(kept):
            entry = ranked[slot]
            if entry[1] == -math.inf:
                continue
            self.bestValues[slot] = entry[0]
            longest = max(entry[0] + 1, longest)

        # Round the word length up to the next even number.
        longest += longest % 2
        self.maxWordLength = longest
        self.sfa.maxWordLength = longest
        return self.sfa.transform(samples, signal)
Example 9
Project: Python   Author: TheAlgorithms   File: lazy_segment_tree.py    License: MIT License 6 votes vote down vote up
def query(self, idx, l, r, a, b):  # noqa: E741
        """Return the maximum over [a, b] within node ``idx``, which covers [l, r].

        Call as ``query(1, 1, N, a, b)``. A pending lazy assignment on ``idx``
        is applied first and pushed down to its children. Ranges that do not
        overlap [a, b] contribute ``-math.inf``.
        """
        if self.flag[idx] is True:
            pending = self.lazy[idx]
            self.st[idx] = pending
            self.flag[idx] = False
            if l != r:  # noqa: E741
                for child in (self.left(idx), self.right(idx)):
                    self.lazy[child] = pending
                    self.flag[child] = True
        if r < a or l > b:
            return -math.inf
        if a <= l and r <= b:  # noqa: E741
            return self.st[idx]
        mid = (l + r) // 2
        left_max = self.query(self.left(idx), l, mid, a, b)
        right_max = self.query(self.right(idx), mid + 1, r, a, b)
        return max(left_max, right_max)
Example 10
Project: ReGraph   Author: Kappa-Dev   File: attribute_sets.py    License: MIT License 6 votes vote down vote up
def from_json(cls, json_data):
        """Create an attribute set object from a JSON-like dictionary.

        ``json_data["type"]`` names a class in this module, which is called
        with the (optional) ``json_data["data"]`` argument. ``data == [None]``
        is treated as no argument. JSON cannot hold tuples or infinities, so:

        * for ``FiniteSet``, list elements become tuples;
        * for ``IntegerSet``, ``"-inf"`` / ``"inf"`` bounds become
          ``-math.inf`` / ``math.inf``.

        The input dictionary is never modified; the normalized arguments are
        built as new lists.

        :return: The constructed set, or None if ``"type"`` is missing.
        """
        if "type" not in json_data:
            return None
        set_type = json_data["type"]

        init_args = None
        if "data" in json_data:
            data = json_data["data"]
            if not (len(data) == 1 and data[0] is None):
                init_args = data

        if set_type == "FiniteSet" and init_args is not None:
            # Lists read back from JSON stand for tuples.
            init_args = [tuple(element) if isinstance(element, list) else element
                         for element in init_args]
        elif set_type == "IntegerSet" and init_args is not None:
            decoded = []
            for element in init_args:
                bounds = list(element)  # copy: never touch the caller's lists
                if bounds[0] == "-inf":
                    bounds[0] = -math.inf
                if bounds[1] == "inf":
                    bounds[1] = math.inf
                decoded.append(bounds)
            init_args = decoded

        return getattr(sys.modules[__name__], set_type)(init_args)
Example 11
Project: ReGraph   Author: Kappa-Dev   File: attribute_sets.py    License: MIT License 6 votes vote down vote up
def __str__(self):
        """Render the intervals as text, e.g. ``[-inf, 3], {5}, [7, inf]``.

        Finite bounds are formatted with ``%d``. An interval whose two bounds
        render identically is shown as a singleton ``{n}``.
        """
        rendered = []
        for start, end in self.intervals:
            low = "%d" % start if start > -math.inf else "-inf"
            high = "%d" % end if end < math.inf else "inf"
            if low == high:
                rendered.append("{%s}" % low)
            else:
                rendered.append("[%s, %s]" % (low, high))
        return ", ".join(rendered)
Example 12
Project: ReGraph   Author: Kappa-Dev   File: attribute_sets.py    License: MIT License 6 votes vote down vote up
def to_json(self):
        """Return a JSON-serializable representation of this IntegerSet.

        Strict JSON has no infinity literal, so an infinite lower bound is
        encoded as the string ``"-inf"`` and an infinite upper bound as
        ``"inf"``. Finite bounds are emitted unchanged. Every interval is
        included, and an empty set yields ``"data": []``.
        """
        data = []
        for start, end in self.intervals:
            # Compare with the signed infinity; math.isinf() is sign-blind.
            new_start = "-inf" if start == -math.inf else start
            new_end = "inf" if end == math.inf else end
            # Append inside the loop, otherwise only the last interval survives.
            data.append([new_start, new_end])
        return {"type": "IntegerSet", "data": data}
Example 13
Project: FATE   Author: FederatedAI   File: heap.py    License: Apache License 2.0 6 votes vote down vote up
def cal_score(self):
        """Set ``self.score`` from the Gini impurity of the merged buckets.

        gini = 1 - sum(p_i^2) = 1 - (event / total)^2 - (nonevent / total)^2

        The merged event/non-event counts are stored on ``self``. The score is
        the merged Gini minus the two buckets' own ``gini`` values. It is
        ``-math.inf`` when ``self.total_count`` is 0.
        """
        left, right = self.left_bucket, self.right_bucket
        self.event_count = left.event_count + right.event_count
        self.non_event_count = left.non_event_count + right.non_event_count
        total = self.total_count
        if total == 0:
            self.score = -math.inf
            return

        event_share = 1.0 * self.event_count / total
        non_event_share = 1.0 * self.non_event_count / total
        merged_gini = 1 - event_share ** 2 - non_event_share ** 2
        self.score = merged_gini - left.gini - right.gini
Example 14
Project: matchpy   Author: HPAC   File: many_to_one.py    License: MIT License 5 votes vote down vote up
def replace(self, expression: Expression, max_count: Union[int, float] = math.inf) -> Union[Expression, Sequence[Expression]]:
        """Replace all occurrences of the patterns according to the replacement rules.

        Each pass applies at most one rule, to the first matching
        subexpression in preorder. Passes repeat until nothing matches or
        *max_count* passes have run.

        Args:
            expression:
                The expression to which the replacement rules are applied.
            max_count:
                If given, at most *max_count* applications of the rules are performed. Otherwise, the rules
                are applied until there is no more match. If the set of replacement rules is not confluent,
                the replacement might not terminate without a *max_count* set. The default is ``math.inf``
                (a float), hence the ``Union[int, float]`` annotation.

        Returns:
            The resulting expression after the application of the replacement rules. This can also be a sequence of
            expressions, if the root expression is replaced with a sequence of expressions by a rule.
        """
        replaced = True
        replace_count = 0
        while replaced and replace_count < max_count:
            replaced = False
            for subexpr, pos in preorder_iter_with_position(expression):
                try:
                    replacement, subst = next(iter(self.matcher.match(subexpr)))
                except StopIteration:
                    # No rule matches this subexpression; try the next one.
                    continue
                # Kept outside the ``try`` so a StopIteration raised by a
                # replacement callback is not mistaken for "no match".
                expression = functions.replace(expression, pos, replacement(**subst))
                replaced = True
                break
            replace_count += 1
        return expression
Example 15
Project: Advanced-Data-Structures-with-Python   Author: bhavinjawade   File: segment_Tree.py    License: MIT License 5 votes vote down vote up
def query(st,ql,qh,low,high,pos):
  """Return the minimum of the segment tree ``st`` over [ql, qh].

  Node ``pos`` covers [low, high]; its children are ``2*pos + 1`` and
  ``2*pos + 2``. Start with ``query(st, ql, qh, 0, n - 1, 0)``. Ranges
  disjoint from [ql, qh] contribute ``math.inf``.
  """
  if ql <= low and qh >= high:
    return st[pos]
  if ql > high or qh < low:
    return math.inf
  # Integer midpoint: float division misaligns child ranges and drops elements.
  mid = (low + high) // 2
  return min(query(st, ql, qh, low, mid, 2 * pos + 1),
             query(st, ql, qh, mid + 1, high, 2 * pos + 2))
Example 16
Project: FINE   Author: FZJ-IEK3-VSA   File: robustPipelineSizing.py    License: MIT License 5 votes vote down vote up
def _postprocessing(scenario, dic_scenario_flows, graph, **kwargs):
    """Pick pressure levels for one scenario by fixing each node's upper bound in turn.

    Nodes are tried in ``graph.nodes`` order, each serving as the node whose
    pressure is fixed to its upper bound. The first candidate for which
    ``computePressureAtNode`` reports validity wins. If none is valid, the
    levels with the smallest reported violation are kept.

    :return: ``(scenario, pressure_levels, max_violation)``. The levels are
        ``{}`` and the violation ``math.inf`` if ``graph`` has no nodes.
    """
    best_levels = {}
    best_violation = math.inf
    for candidate in copy.deepcopy(list(graph.nodes)):
        # Fresh all-unknown table; presumably filled by computePressureAtNode.
        pressures = {node: None for node in list(graph.nodes)}
        valid, violation = computePressureAtNode(
            graph=graph, node=candidate, nodeUpperBound=candidate,
            dic_scenario_flows=dic_scenario_flows[scenario],
            dic_node_pressure=pressures, **kwargs)
        if valid:
            # Feasible levels found: keep them and stop searching.
            best_levels, best_violation = pressures, violation
            break
        if violation < best_violation:
            # Best infeasible candidate so far.
            best_levels = copy.deepcopy(pressures)
            best_violation = violation
    return scenario, best_levels, best_violation
Example 17
Project: ConvLab   Author: ConvLab   File: dataset_reader.py    License: MIT License 5 votes vote down vote up
def find_best_delex_act(self, action):
        """Return the index in ``self.action_list`` that best matches ``action``.

        An exact match wins immediately. Otherwise the preferred candidates
        are those contained in ``action`` (``action`` has extra items, the
        candidate has none), taking the one missing the fewest items. Failing
        that, the candidate with the fewest differences in both directions is
        used. Returns -1 if ``self.action_list`` is empty.
        """
        def _score(a1, a2):
            # Items of a1 absent from a2 (a whole entry counts if its key is absent).
            return sum(
                len(a1[key]) if key not in a2 else len(set(a1[key]) - set(a2[key]))
                for key in a1
            )

        subset_index, subset_score = -1, math.inf
        closest_index, closest_score = -1, math.inf
        for idx, candidate in enumerate(self.action_list):
            if candidate == action:
                return idx
            missing = _score(action, candidate)
            extra = _score(candidate, action)
            if missing > 0 and extra == 0 and missing < subset_score:
                subset_index, subset_score = idx, missing
            elif missing + extra < closest_score:
                closest_index, closest_score = idx, missing + extra
        return subset_index if subset_index >= 0 else closest_index
Example 18
Project: Jtyoui   Author: jtyoui   File: theorem.py    License: MIT License 5 votes vote down vote up
def theorem_Zero(function, x1: float, x2: float, tol: float = 1e-9) -> float:
    """Find a zero of ``function`` between ``x1`` and ``x2`` by bisection.

    Relies on the intermediate value theorem: if f(x1) and f(x2) have
    opposite signs, the interval is halved repeatedly toward the sign change.
    Example: f(x) = x**3 - 2*x - 5 on [1, 1000].

    :param function: Callable f(x) -> float whose zero is sought.
    :param x1: One end of the search interval.
    :param x2: The other end of the search interval.
    :param tol: Stop once the half-interval is at most this wide
        (default 1e-9, the previously hard-coded value).
    :return: Approximate zero. Returns ``math.inf`` (after emitting a
        ``MathValueWarning``) when f(x1) and f(x2) have the same sign.
    """
    f1 = function(x1)
    if f1 == 0:
        return x1
    f2 = function(x2)
    if f2 == 0:
        return x2
    if f1 * f2 > 0:
        warnings.warn('[a,b]区间的值应该满足:f(a)*f(b)<0', category=MathValueWarning)
        return math.inf

    mid = x1 + (x2 - x1) / 2.0
    # Also stop once the midpoint rounds onto x2: for large |x| the float
    # spacing exceeds ``tol`` and the interval can no longer shrink, which
    # previously caused an infinite loop.
    while abs(x1 - mid) > tol and mid != x2:
        f_mid = function(mid)
        if f_mid == 0:
            return mid
        if f_mid * f1 < 0:
            x2 = mid
        else:
            x1, f1 = mid, f_mid
        mid = x1 + (x2 - x1) / 2.0
    return mid
Example 19
Project: workload-collocation-agent   Author: intel   File: allocations.py    License: Apache License 2.0 5 votes vote down vote up
def __init__(self, value: Union[float, int],
                 common_labels: Dict[str, str] = None,
                 min_value: Optional[Union[int, float]] = 0,
                 max_value: Optional[Union[int, float]] = None,
                 value_change_sensitivity: float = VALUE_CHANGE_SENSITIVITY,
                 ):
        """Store a numeric value together with its bounds and labels.

        ``value`` must be an int or float. This is checked with ``assert``,
        so the check is skipped under ``python -O``. A bound given as ``None``
        means unbounded on that side (``-math.inf`` / ``math.inf``).
        ``common_labels`` defaults to an empty dict.
        """
        assert isinstance(value, (float, int)), \
            'should be of type (float, int) but was {}'.format(type(value))
        self.value = value
        self.value_change_sensitivity = value_change_sensitivity
        self.min_value = -math.inf if min_value is None else min_value
        self.max_value = math.inf if max_value is None else max_value
        self.labels_updater = LabelsUpdater(common_labels or {})
Example 20
Project: resolwe   Author: genialis   File: listener.py    License: Apache License 2.0 5 votes vote down vote up
def __init__(self, *args, **kwargs):
        """Initialize listener state.

        Keyword arguments read here:

        :param redis_params: Optional. Parameters for the Redis connection
            (defaults to an empty dict).
        :param verbosity: Optional. Verbosity level passed around to Resolwe
            utilities (defaults to 1).

        Positional ``args`` are accepted but not used by this method.
        """
        super().__init__()

        # Redis connection object (None until set) and its parameters.
        self._redis = None
        self._redis_params = kwargs.get("redis_params", {})

        # Running coordination.
        self._should_stop, self._runner_coro = False, None

        self._verbosity = kwargs.get("verbosity", 1)

        # Per-event service-time statistics and load averages over 1/5/15 minutes.
        self.service_time = stats.NumberSeriesShape()
        self.load_avg = stats.SimpleLoadAvg([60, 5 * 60, 15 * 60])

        # Throttling state for critical-load messages: last log time and level.
        self.last_load_log, self.last_load_level = -math.inf, 0
Example 21
Project: resolwe   Author: genialis   File: stats.py    License: Apache License 2.0 5 votes vote down vote up
def __init__(self):
        """Start an empty series: extremes at -inf/+inf, all running statistics 0."""
        self.high, self.low = -math.inf, math.inf
        self.mean = self.deviation = self.count = 0
        self._rolling_variance = 0
Example 22
Project: resolwe   Author: genialis   File: stats.py    License: Apache License 2.0 5 votes vote down vote up
def __init__(self, intervals):
        """Construct an instance of the class.

        :param intervals: A list of interval lengths, in seconds. Each
            interval object is reachable in ``self.intervals`` both by its
            length and by its ``display`` name.
        """
        self.last_data = -math.inf
        by_length = {length: SimpleLoadAvg._Interval(length) for length in intervals}
        by_display = {meta.display: meta for meta in by_length.values()}
        self.intervals = {**by_length, **by_display}
Example 23
Project: resolwe   Author: genialis   File: test_stats.py    License: Apache License 2.0 5 votes vote down vote up
def test_shape_basic(self):
        series = stats.NumberSeriesShape()
        # A fresh series reports the empty-shape sentinels.
        empty_shape = {
            "high": -math.inf,
            "low": math.inf,
            "mean": 0,
            "count": 0,
            "deviation": 0,
        }
        self.assertEqual(series.to_dict(), empty_shape)

        def assert_all_ones(expected_count):
            # Every sample so far was 1, so the shape is degenerate.
            self.assertEqual(series.count, expected_count)
            self.assertAlmostEqual(series.high, 1.0)
            self.assertAlmostEqual(series.low, 1.0)
            self.assertAlmostEqual(series.mean, 1.0)
            self.assertAlmostEqual(series.deviation, 0.0)

        series.update(1)
        assert_all_ones(1)

        for _ in range(5):
            series.update(1)
        assert_all_ones(6)

        # Two symmetric outliers: the extremes follow them, the mean shifts slightly.
        large = 1000000.0
        for sample in (large, -large):
            series.update(sample)
        self.assertAlmostEqual(series.high, large)
        self.assertAlmostEqual(series.low, -large)
        self.assertAlmostEqual(series.mean, 0.75)
        self.assertAlmostEqual(series.deviation, 534522.483825049)
Example 24
Project: video_captioning_rl   Author: ramakanth-pasunuru   File: esim.py    License: MIT License 5 votes vote down vote up
def get_U_tile(self, S, s2):
        """Attend over ``s2`` with row-wise softmax weights from ``S``.

        Slices of ``S`` that are entirely ``-inf`` make softmax produce NaN.
        Those weights are reset to 0, so such rows attend to nothing.
        """
        attention = F.softmax(S, dim=2)  # [B, t1, t2]
        attention.data.masked_fill_(torch.isnan(attention.data), 0)
        return torch.bmm(attention, s2)  # [B, t1, t2] * [B, t2, D] -> [B, t1, D]
Example 25
Project: video_captioning_rl   Author: ramakanth-pasunuru   File: esim.py    License: MIT License 5 votes vote down vote up
def get_both_tile(self, S, s1, s2):
        """Attend in both directions over the similarity matrix ``S``.

        Returns ``(U_tile, U1_tile)``. ``U_tile`` is ``s2`` attended with a
        softmax over dim 2 ([B, t1, D]). ``U1_tile`` is ``s1`` attended with a
        softmax over dim 1 ([B, t2, D]). NaN weights, which come from all
        ``-inf`` slices, are reset to 0.
        """
        row_weights = F.softmax(S, dim=2)  # [B, t1, t2]
        row_weights.data.masked_fill_(torch.isnan(row_weights.data), 0)
        col_weights = F.softmax(S, dim=1)  # [B, t1, t2]
        col_weights.data.masked_fill_(torch.isnan(col_weights.data), 0)

        U_tile = torch.bmm(row_weights, s2)  # [B, t1, t2] * [B, t2, D] -> [B, t1, D]
        U1_tile = torch.bmm(col_weights.transpose(1, 2), s1)  # [B, t2, t1] * [B, t1, D] -> [B, t2, D]
        return U_tile, U1_tile
Example 26
Project: crosentgec   Author: nusnlp   File: train.py    License: GNU General Public License v3.0 5 votes vote down vote up
def get_perplexity(loss):
    """Return ``2 ** loss`` formatted with two decimals (assumes a base-2 loss).

    When the power overflows, the string ``'inf'`` is returned. The result
    is therefore always a ``str``; previously a float was returned in that
    case.
    """
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return '{:.2f}'.format(math.inf)
Example 27
Project: argus-tgs-salt   Author: lRomul   File: lr_scheduler.py    License: MIT License 5 votes vote down vote up
def __init__(self,
                 monitor='val_loss',
                 factor=0.1,
                 patience=1,
                 min_lr=1e-6,
                 better='auto'):
        """Store scheduler settings and resolve the comparison direction.

        With ``better='auto'``, the direction comes from
        ``METRIC_REGISTRY[name].better``. ``name`` is ``monitor`` with a
        leading ``'val_'`` or ``'train_'`` removed if present. Sets
        ``better_comp`` (the "is improvement" predicate) and ``best_value``
        (``math.inf`` for ``'min'``, ``-math.inf`` for ``'max'``).

        :raises ImportError: If the metric name is not in ``METRIC_REGISTRY``
            (exception type kept for backward compatibility).
        """
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.better = better

        if self.better == 'auto':
            # Strip only a prefix that is actually present; previously any
            # non-'val_' monitor lost its first 6 characters ('loss' -> '').
            metric_name = monitor
            for prefix in ('val_', 'train_'):
                if metric_name.startswith(prefix):
                    metric_name = metric_name[len(prefix):]
                    break
            if metric_name not in METRIC_REGISTRY:
                raise ImportError(f"Metric '{metric_name}' not found in scope")
            self.better = METRIC_REGISTRY[metric_name].better
        assert self.better in ['min', 'max', 'auto'], \
            f"Unknown better option '{self.better}'"

        if self.better == 'min':
            self.better_comp = lambda a, b: a < b
            self.best_value = math.inf
        elif self.better == 'max':
            self.better_comp = lambda a, b: a > b
            self.best_value = -math.inf

        self.wait = 0
Example 28
Project: argus-tgs-salt   Author: lRomul   File: lr_scheduler.py    License: MIT License 5 votes vote down vote up
def start(self, state: State):
        """Reset the patience counter and the best observed value.

        The best value starts at ``math.inf`` when lower is better
        (``'min'``) and at ``-math.inf`` otherwise.
        """
        self.wait = 0
        if self.better == 'min':
            self.best_value = math.inf
        else:
            self.best_value = -math.inf
Example 29
Project: PlaNet   Author: Kaixhin   File: main.py    License: MIT License 5 votes vote down vote up
def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, min_action=-inf, max_action=inf, explore=False):
  """Update the belief from the latest observation, plan an action and step ``env``.

  When ``explore`` is set, Gaussian noise scaled by ``args.action_noise`` is
  added to the planned action. The action is clamped in place to
  [min_action, max_action] before stepping.

  :return: ``(belief, posterior_state, action, next_observation, reward, done)``.
  """
  # Infer belief over current state q(s_t|o≤t,a<t) from the history
  belief, _, _, _, posterior_state, _, _ = transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoder(observation).unsqueeze(dim=0))  # Action and observation need extra time dimension
  belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
  action = planner(belief, posterior_state)  # Get action from planner(q(s_t|o≤t,a<t), p)
  if explore:
    action = action + args.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
  # Clip action range. This clamps ``action``; the old ``actions`` was an undefined name (NameError).
  action.clamp_(min=min_action, max=max_action)
  next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu())  # Perform environment step (action repeats handled internally)
  return belief, posterior_state, action, next_observation, reward, done


# Testing only 
Example 30
Project: PlaNet   Author: Kaixhin   File: planner.py    License: MIT License 5 votes vote down vote up
def __init__(self, action_size, planning_horizon, optimisation_iters, candidates, top_candidates, transition_model, reward_model, min_action=-inf, max_action=inf):
    """Store the planner's models, action bounds and search hyper-parameters.

    ``min_action`` / ``max_action`` default to an unbounded range
    (``-inf`` / ``inf``). NOTE(review): ``candidates`` / ``top_candidates``
    and ``optimisation_iters`` suggest a sampling-based planner that keeps
    the top candidates each iteration; confirm against the planning method.
    """
    super().__init__()
    # Models used to roll out and score candidate action sequences.
    self.transition_model, self.reward_model = transition_model, reward_model
    # Action dimensionality and per-dimension bounds.
    self.action_size, self.min_action, self.max_action = action_size, min_action, max_action
    self.planning_horizon = planning_horizon
    self.optimisation_iters = optimisation_iters
    self.candidates, self.top_candidates = candidates, top_candidates