import unittest
import exp_replay
from exp_replay import Step
import numpy as np


class ExpReplayTest(unittest.TestCase):
  """
  Unit tests for the ExpReplay class.

  Every test pushes 120 consecutive steps into a memory created with
  mem_size=100, so the memory has to drop steps to stay at capacity.
  """

  def _fill(self, exprep, make_state, num_steps=120):
    """
    Add num_steps transitions to exprep.

    Transition i goes from make_state(i) to make_state(i + 1) with action 0,
    reward 0 and done=False. make_state is called separately for each field,
    so no two steps share a (possibly mutable) state object.
    """
    # range rather than xrange: equivalent here, and it also runs on Python 3.
    for i in range(num_steps):
      exprep.add_step(Step(cur_step=make_state(i), action=0,
                           next_step=make_state(i + 1), reward=0,
                           done=False))

  def test1(self):
    """kth=1: memory is capped at mem_size and the newest step is last."""
    exprep = exp_replay.ExpReplay(mem_size=100, state_size=[1], kth=1)
    self._fill(exprep, lambda i: i)
    self.assertEqual(len(exprep.mem), 100)
    self.assertEqual(exprep.mem[-1:][0].cur_step, 119)

  def test2(self):
    """kth=4 with scalar states: last state is the 4 most recent steps."""
    exprep = exp_replay.ExpReplay(mem_size=100, state_size=[1], kth=4)
    self._fill(exprep, lambda i: i)
    self.assertEqual(len(exprep.mem), 100)
    self.assertEqual(exprep.mem[-1:][0].cur_step, 119)
    self.assertEqual(exprep.get_last_state(), [116, 117, 118, 119])

  def test3(self):
    """kth=4 with 2x2 states: recent states are stacked on a trailing axis."""
    exprep = exp_replay.ExpReplay(mem_size=100, state_size=[2, 2], kth=4)
    self._fill(exprep, lambda i: [[i, i], [i, i]])
    self.assertEqual(len(exprep.mem), 100)
    self.assertEqual(exprep.mem[-1:][0].cur_step, [[119, 119], [119, 119]])

    last_state = exprep.get_last_state()
    self.assertEqual(np.shape(last_state), (2, 2, 4))
    # Slice k of the trailing axis holds step 116 + k (oldest first).
    for k in range(4):
      v = 116 + k
      np.testing.assert_array_equal(last_state[:, :, k], [[v, v], [v, v]])

    # Sampled transitions carry stacked states of the same shape.
    sample = exprep.sample(5)
    self.assertEqual(len(sample), 5)
    self.assertEqual(np.shape(sample[0].cur_step), (2, 2, 4))
    self.assertEqual(np.shape(sample[0].next_step), (2, 2, 4))

  def test4(self):
    """kth=-1: states are returned raw, without stacking."""
    # -1 for sending raw state
    exprep = exp_replay.ExpReplay(mem_size=100, state_size=[4], kth=-1)
    self._fill(exprep, lambda i: [i, i, i, i])
    last_state = exprep.get_last_state()
    self.assertEqual(np.shape(last_state), (4,))
    np.testing.assert_array_equal(last_state, [119, 119, 119, 119])

    sample = exprep.sample(5)
    self.assertEqual(len(sample), 5)
    self.assertEqual(np.shape(sample[0].cur_step), (4,))

if __name__ == '__main__':
  # Run every test_* method in this module when executed as a script.
  unittest.main(module='__main__')