# -*- coding: utf-8 -*-
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Dict, List, Optional

import numpy as np

from mabwiser.base_mab import BaseMAB
from mabwiser.utils import Arm, Num, argmax, reset, _BaseRNG


class _ThompsonSampling(BaseMAB):

    def __init__(self, rng: _BaseRNG, arms: List[Arm], n_jobs: int, backend: Optional[str],
                 binarizer: Optional[Callable] = None):
        super().__init__(rng, arms, n_jobs, backend)
        self.binarizer = binarizer

        # Track whether the rewards have already been binarized by an external contextual policy
        self.is_contextual_binarized = False

        self.arm_to_success_count = dict.fromkeys(self.arms, 1)
        self.arm_to_fail_count = dict.fromkeys(self.arms, 1)

    def fit(self, decisions: np.ndarray, rewards: np.ndarray,
            contexts: Optional[np.ndarray] = None) -> None:

        # If rewards are non-binary, convert them
        rewards = self._get_binary_rewards(decisions, rewards)

        # Reset the success and failure counters to 1 (the beta distribution is undefined at 0)
        reset(self.arm_to_success_count, 1)
        reset(self.arm_to_fail_count, 1)

        # Calculate fit
        self._parallel_fit(decisions, rewards)

        # Leave the calculation of expectations to the predict methods

    def partial_fit(self, decisions: np.ndarray, rewards: np.ndarray,
                    contexts: Optional[np.ndarray] = None) -> None:

        # If rewards are non-binary, convert them
        rewards = self._get_binary_rewards(decisions, rewards)

        # Calculate fit
        self._parallel_fit(decisions, rewards)

    def predict(self, contexts: Optional[np.ndarray] = None) -> Arm:

        # Return the arm with the maximum expectation. If multiple maxima exist, return the first one.
        return argmax(self.predict_expectations())

    def predict_expectations(self, contexts: Optional[np.ndarray] = None) -> Dict[Arm, Num]:

        # The expectation of each arm is a random sample from a beta distribution
        # parameterized by its success and failure counters
        for arm in self.arm_to_expectation:
            self.arm_to_expectation[arm] = self.rng.beta(self.arm_to_success_count[arm],
                                                         self.arm_to_fail_count[arm])

        # Return a copy of the expectations dictionary from arms (keys) to expectations (values)
        return self.arm_to_expectation.copy()

    def _fit_arm(self, arm: Arm, decisions: np.ndarray, rewards: np.ndarray,
                 contexts: Optional[np.ndarray] = None):

        # Rewards are binary at this point, so the number of successes is the sum of ones
        arm_rewards = rewards[decisions == arm]
        count_of_ones = arm_rewards.sum()
        self.arm_to_success_count[arm] += count_of_ones
        self.arm_to_fail_count[arm] += len(arm_rewards) - count_of_ones

    def _predict_contexts(self, contexts: np.ndarray, is_predict: bool,
                          seeds: Optional[np.ndarray] = None,
                          start_index: Optional[int] = None) -> List:
        pass

    def _get_binary_rewards(self, decisions: np.ndarray, rewards: np.ndarray):

        # If a binarizer function is given and binarization has not already taken place
        # in a neighborhood policy, convert every decision-reward pair to a binary reward
        if self.binarizer and not self.is_contextual_binarized:
            return np.fromiter((self.binarizer(decisions[index], value)
                                for index, value in enumerate(rewards)), rewards.dtype)
        else:
            return rewards

    def _uptake_new_arm(self, arm: Arm, binarizer: Callable = None, scaler: Callable = None):

        # Don't override the existing binarizer unless a new one is given
        if binarizer:
            self.binarizer = binarizer

        # Start the new arm with a uniform Beta(1, 1) prior
        self.arm_to_success_count[arm] = 1
        self.arm_to_fail_count[arm] = 1
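

if __name__ == "__main__":
    # Usage sketch (illustrative only): exercises this policy through mabwiser's
    # public MAB wrapper rather than by instantiating the private class directly.
    # The arm labels, decisions, rewards, and the 0.5 binarization threshold below
    # are made-up example values; only MAB, LearningPolicy.ThompsonSampling(binarizer),
    # fit, predict, and predict_expectations are taken from the public interface.
    from mabwiser.mab import MAB, LearningPolicy

    mab = MAB(arms=["arm1", "arm2"],
              learning_policy=LearningPolicy.ThompsonSampling(
                  binarizer=lambda decision, reward: reward > 0.5))

    # Non-binary rewards are binarized by the lambda above before the
    # success/failure counters are updated
    mab.fit(decisions=["arm1", "arm1", "arm2"], rewards=[0.9, 0.2, 0.7])

    print(mab.predict())               # arm whose Beta sample is largest
    print(mab.predict_expectations())  # Beta-sampled expectation per arm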