"""Unit test for LinUCB
"""
import unittest

import numpy as np

from striatum.bandit import LinUCB
from striatum.storage import (
    MemoryHistoryStorage,
    MemoryModelStorage,
    MemoryActionStorage,
    Action,
)
from .base_bandit_test import BaseBanditTest, ChangeableActionSetBanditTest


class TestLinUCB(ChangeableActionSetBanditTest,
                 BaseBanditTest,
                 unittest.TestCase):
    """Unit tests for the LinUCB bandit policy.

    Common bandit tests come from the ``BaseBanditTest`` and
    ``ChangeableActionSetBanditTest`` mixins; this class adds checks on the
    LinUCB hyper-parameters and on its per-action ``A`` / ``b`` model.
    """
    # pylint: disable=protected-access

    def setUp(self):
        """Create a LinUCB policy on the shared storages and one on empty ones."""
        # NOTE(review): history_storage / model_storage / action_storage are
        # presumably created by the base-class setUp — confirm there.
        super(TestLinUCB, self).setUp()
        self.context_dimension = 2
        self.alpha = 1.
        self.policy = LinUCB(
            self.history_storage, self.model_storage,
            self.action_storage, context_dimension=self.context_dimension,
            alpha=self.alpha)
        # Second policy backed by fresh, empty in-memory storages.
        self.policy_with_empty_action_storage = LinUCB(
            MemoryHistoryStorage(), MemoryModelStorage(), MemoryActionStorage(),
            context_dimension=self.context_dimension, alpha=self.alpha)

    def test_initialization(self):
        """Check the policy keeps the configured context_dimension and alpha."""
        super(TestLinUCB, self).test_initialization()
        policy = self.policy
        self.assertEqual(self.context_dimension, policy.context_dimension)
        self.assertEqual(self.alpha, policy.alpha)

    def test_model_storage(self):
        """Check the model has one ``b`` and one ``A`` entry per action.

        Each ``b`` entry has length ``context_dimension`` and each ``A`` entry
        is a ``context_dimension`` x ``context_dimension`` array.
        """
        model = self.policy._model_storage.get_model()
        n_actions = self.action_storage.count()
        self.assertEqual(len(model['b']), n_actions)
        # assumes an action with id 1 exists in the fixture action storage
        self.assertEqual(len(model['b'][1]), self.context_dimension)
        self.assertEqual(len(model['A']), n_actions)
        self.assertEqual(model['A'][1].shape,
                         (self.context_dimension, self.context_dimension))

    def test_add_action(self):
        """Check that actions added at runtime get a model and can be updated.

        Newly added actions must start with ``A`` equal to the identity
        matrix and, after being rewarded, must no longer have an identity
        ``A``.
        """
        policy = self.policy
        identity = np.identity(self.context_dimension)

        context1 = {1: [1, 1], 2: [2, 2], 3: [3, 3]}
        history_id, _ = policy.get_action(context1, 2)
        new_actions = [Action() for _ in range(2)]
        policy.add_action(new_actions)
        self.assertEqual(len(new_actions) + len(self.actions),
                         policy._action_storage.count())
        policy.reward(history_id, {3: 1})
        model = policy._model_storage.get_model()
        for action in new_actions:
            self.assertTrue(np.array_equal(model['A'][action.id], identity))

        # NOTE(review): context keys 4 and 5 assume the two new actions are
        # assigned ids 4 and 5 by the action storage — confirm.
        context2 = {1: [1, 1], 2: [2, 2], 3: [3, 3], 4: [4, 4], 5: [5, 5]}
        history_id2, recommendations = policy.get_action(context2, 4)
        self.assertEqual(len(recommendations), 4)
        policy.reward(history_id2, {new_actions[0].id: 4, new_actions[1].id: 5})
        model = policy._model_storage.get_model()
        for action in new_actions:
            # Use the fixture dimension rather than a hard-coded 2 so the
            # check stays valid if context_dimension changes.
            self.assertFalse(np.array_equal(model['A'][action.id], identity))