# Copyright (c) 2018, The SenseAct Authors.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
from multiprocessing import Array

from senseact.utils import get_random_state_array, get_random_state_from_array


class TestUtils(unittest.TestCase):
    """Tests for the random-state array helpers imported from ``senseact.utils``."""

    def test_random_state_array(self):
        """Round-trip a ``RandomState`` state through a shared buffer.

        The state of a seeded generator is captured, converted with
        ``get_random_state_array``, copied into a ``multiprocessing.Array``
        backed NumPy view, and converted back with
        ``get_random_state_from_array``. A second generator restored from that
        state must reproduce the exact uniform and normal draws of the first.
        """
        rand_obj = np.random.RandomState(1)
        # Snapshot the state *before* drawing, so restoring it replays the
        # same sequence of samples.
        rand_state = rand_obj.get_state()
        original_uniform_values = rand_obj.uniform(-1, 1, 100)
        original_normal_values = rand_obj.randn(100)

        rand_state_array_type, rand_state_array_size, rand_state_array = get_random_state_array(rand_state)
        # Array('b', n) allocates n one-byte elements; the buffer is then
        # reinterpreted with the dtype reported by the helper.
        # NOTE(review): this implies rand_state_array_size is a byte count -- confirm in senseact.utils.
        shared_rand_array = np.frombuffer(Array('b', rand_state_array_size).get_obj(), dtype=rand_state_array_type)
        np.copyto(shared_rand_array, np.frombuffer(rand_state_array, dtype=rand_state_array_type))

        # Restore the state from the shared buffer into a fresh generator and
        # repeat the draws in the same order as above.
        new_rand_obj = np.random.RandomState()
        new_rand_obj.set_state(get_random_state_from_array(shared_rand_array))
        new_uniform_values = new_rand_obj.uniform(-1, 1, 100)
        new_normal_values = new_rand_obj.randn(100)

        # Use NumPy's testing helpers rather than bare ``assert``: they are not
        # stripped under ``python -O``, they verify shapes as well as values,
        # and they report the mismatching elements on failure.
        np.testing.assert_array_equal(new_uniform_values, original_uniform_values)
        np.testing.assert_array_equal(new_normal_values, original_normal_values)


if __name__ == "__main__":
    # Run the tests only when this file is executed as a script.
    # buffer=True makes unittest capture stdout/stderr during each test and
    # show it only for tests that fail or error.
    unittest.main(buffer=True)