"""Example script: run a UCB1 bandit on simulated data and plot its regret."""
from six.moves import range

import matplotlib.pyplot as plt

from striatum.storage import (
    MemoryHistoryStorage,
    MemoryModelStorage,
    MemoryActionStorage,
    Action,
)
from striatum.bandit import UCB1
from striatum import simulation


def main():
    """Simulate a UCB1 policy and display its average-regret plot.

    Five actions (ids 0-4) are registered in an in-memory action storage,
    ``simulation.simulate_data`` supplies per-round contexts and desired
    actions, and the policy is rewarded with 1 when its recommendation
    equals the desired action for that round and 0 otherwise.
    """
    context_dimension = 5

    # Register the candidate actions before any data is simulated.
    action_storage = MemoryActionStorage()
    candidate_actions = [Action(i) for i in range(5)]
    action_storage.add(candidate_actions)

    # Regret Analysis
    n_rounds = 10000
    context, desired_actions = simulation.simulate_data(
        n_rounds,
        context_dimension,
        action_storage,
        random_state=1,
    )

    history_storage = MemoryHistoryStorage()
    model_storage = MemoryModelStorage()
    policy = UCB1(history_storage, model_storage, action_storage)

    for step in range(n_rounds):
        history_id, picked = policy.get_action(context[step])
        chosen_id = picked.action.id
        # A mismatch with the desired action earns 0; a match earns 1.
        payoff = 0 if desired_actions[step] != chosen_id else 1
        policy.reward(history_id, {chosen_id: payoff})

    policy.plot_avg_regret()
    plt.show()


if __name__ == '__main__':
    main()