# ''' May,08, 2014 @author Hideki Ikeda Unit test for the StockValue class in stock_value.py ''' import os import sys from datetime import date, datetime, timedelta import shutil import unittest from sklearn import linear_model sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')) from chelmbigstock import stock_value class TestStockValue(unittest.TestCase): """ Unit tests for the StockValue class StockValue uses the StockHist and LinearAdapter classes. Make sure these classes work before testing StockValue """ _test_data = [ # [ symbol, start_date, end_date ] [ 'HPQ', date(1993, 6, 14), date(1993, 7, 19) ], # small data [ 'MSFT', None, None ] # large data ] @classmethod def setUpClass(cls): cls._my_path = os.path.abspath(os.path.dirname(__file__)) cls._result_path = os.path.join(cls._my_path, 'result') cls._expected_path = os.path.join(cls._my_path, 'expected') os.mkdir(cls._result_path) stock_value.StockHist.default_path = cls._result_path cls._predictor = stock_value.LinearAdapter(linear_model.Ridge(alpha=0.1, fit_intercept=False)) cls._comment = 'Ridge alpha=0.1' cls._histories = [] for symbol, sdate, edate in cls._test_data: cls._histories.append(stock_value.StockHist(symbol, sdate, edate)) @classmethod def tearDownClass(cls): shutil.rmtree(cls._result_path) def test_past(self): for hist in self._histories: stock = stock_value.StockValue(hist, self._predictor, self._comment) # test dates expected = hist.dates result = stock.past_dates self.assertEqual(expected, result) # test highs expected = hist.highs result = stock.past_highs self.assertEqual(expected, result) def test_comment(self): stock = stock_value.StockValue(self._histories[0], self._predictor, self._comment) self.assertEqual(self._comment, stock.comment) def test_future(self): future_dates = [ date.today() + timedelta(days=1), date.today() + timedelta(days=30) ] for hist in self._histories: stock = stock_value.StockValue(hist, self._predictor, self._comment) result = stock.future_highs(future_dates) self.assertEqual(len(future_dates), len(result)) if __name__ == "__main__": unittest.main()