"""Tests for Config, DBConnector, GitHub, PackageManagers, SendGrid and LICENSE.

Everything except TestLicenseYear depends on project modules that are only
imported when the TRAVIS environment variable is unset; those test classes are
skipped (not silently passed) on Travis.
"""
import os
import datetime

try:
    import unittest2 as unittest
except ImportError:
    import unittest

# A single flag so the import guard and every test agree on "are we on
# Travis?".  Mixing `is None` checks with truthiness checks disagrees when
# TRAVIS is set to an empty string.
ON_TRAVIS = os.environ.get('TRAVIS') is not None

if not ON_TRAVIS:
    from db_connector import (DBConnector, GitHubData, PackageManagerData,
                              get_db_connection_string,)
    from config import Config
    from github import GitHub
    from package_managers import PackageManagers
    from sendgrid_email import SendGrid

try:
    basestring
except NameError:
    basestring = str

# Class decorator: skip tests whose dependencies are not imported on Travis.
skip_on_travis = unittest.skipIf(
    ON_TRAVIS, 'project modules are not imported on Travis CI')


@skip_on_travis
class TestConfig(unittest.TestCase):
    """Check Config attributes, required env vars and DB URL translation."""

    def setUp(self):
        self.config = Config()

    def test_initialization(self):
        # Secrets are read straight from the environment, not from Config.
        for env_var in ('GITHUB_TOKEN', 'SENDGRID_API_KEY', 'MYSQL_DB_URL'):
            self.assertIsInstance(os.environ.get(env_var), basestring,
                                  msg=env_var)

        self.assertIsInstance(self.config.github_user, basestring)
        self.assertIsInstance(self.config.github_repos, list)
        self.assertIsInstance(self.config.package_manager_urls, list)
        self.assertIsInstance(self.config.to_email, basestring)
        self.assertIsInstance(self.config.from_email, basestring)
        self.assertIsInstance(self.config.email_subject, basestring)
        self.assertIsInstance(self.config.email_body, basestring)

    def test_mysql_db_connection_string(self):
        mysql_str = 'mysql://user:pass@host:port/dbname'
        connection_string = get_db_connection_string(mysql_str)
        self.assertEqual(connection_string,
                         'mysql+pymysql://user:pass@host:port/dbname')

    def test_sqllite_db_connection_string(self):
        # SQLite URLs are expected to pass through unchanged.
        for sqlite_url in (
                'sqlite://',            # in memory
                'sqlite:///foo.db',     # relative path
                'sqlite:////foo.db',    # absolute path
        ):
            self.assertEqual(get_db_connection_string(sqlite_url), sqlite_url)


@skip_on_travis
class TestDBConnector(unittest.TestCase):
    """Round-trip rows through DBConnector.add_data / delete_data / get_data."""

    def setUp(self):
        self.db = DBConnector()

    def test_add_and_delete_data(self):
        github_data_import = GitHubData(
            date_updated=datetime.datetime.now(),
            language='repo_name',
            pull_requests=0,
            open_issues=0,
            number_of_commits=0,
            number_of_branches=0,
            number_of_releases=0,
            number_of_contributors=0,
            number_of_watchers=0,
            number_of_stargazers=0,
            number_of_forks=0
        )
        res = self.db.add_data(github_data_import)
        self.assertIsInstance(res, GitHubData)
        res = self.db.delete_data(res.id, 'github_data')
        self.assertTrue(res)

        packagedata = PackageManagerData(
            date_updated=datetime.datetime.now(),
            csharp_downloads=0,
            nodejs_downloads=0,
            php_downloads=0,
            python_downloads=0,
            ruby_downloads=0
        )
        res = self.db.add_data(packagedata)
        self.assertIsInstance(res, PackageManagerData)
        res = self.db.delete_data(res.id, 'package_manager_data')
        self.assertTrue(res)

    def test_get_data(self):
        github_data = self.db.get_data(GitHubData)
        self.assertIsInstance(github_data, list)
        # Fail with a readable message instead of an IndexError below.
        self.assertTrue(github_data, 'get_data(GitHubData) returned no rows')
        self.assertIsInstance(github_data[0], GitHubData)


@skip_on_travis
class TestGitHub(unittest.TestCase):
    """GitHub.update_library_data stores a row, which the test then deletes."""

    def setUp(self):
        self.github = GitHub()
        self.db = DBConnector()
        self.config = Config()

    def test_update_library_data(self):
        res = self.github.update_library_data(self.config.github_user,
                                              self.config.github_repos[0])
        self.assertIsInstance(res, GitHubData)
        res = self.db.delete_data(res.id, 'github_data')
        self.assertTrue(res)


@skip_on_travis
class TestPackageManagers(unittest.TestCase):
    """PackageManagers.update_package_manager_data stores a deletable row."""

    def setUp(self):
        self.pm = PackageManagers()
        self.db = DBConnector()
        self.config = Config()

    def test_update_package_manager_data(self):
        res = self.pm.update_package_manager_data(
            self.config.package_manager_urls)
        self.assertIsInstance(res, PackageManagerData)
        res = self.db.delete_data(res.id, 'package_manager_data')
        self.assertTrue(res)


@skip_on_travis
class TestSendGridEmail(unittest.TestCase):
    """SendGrid.send_email returns a result whose first item is 202."""

    def setUp(self):
        self.sg = SendGrid()
        self.config = Config()

    def test_send_email(self):
        res = self.sg.send_email(
            'elmer.thomas+test@sendgrid.com',
            self.config.from_email,
            self.config.email_subject,
            self.config.email_body
        )
        self.assertEqual(202, res[0])


@skip_on_travis
class TestExportTable(unittest.TestCase):
    """DBConnector.export_table_to_csv writes a header plus one line per row."""

    # Corresponds to schema in `db/data_schema.sql`
    header_row = "id,date_updated,language,pull_requests,open_issues,"\
                 "number_of_commits,number_of_branches,number_of_releases,"\
                 "number_of_contributors,number_of_watchers,"\
                 "number_of_stargazers,number_of_forks\n"

    def setUp(self):
        self.github = GitHub()
        self.db = DBConnector()
        self.config = Config()
        self.filename = "./csv/{}.csv".format(GitHubData.__tablename__)
        # A CSV left by an interrupted earlier run would make
        # test_file_export_succeeds fail for the wrong reason.
        self._remove_export()
        # Ensure the table has at least one row to export; removed in tearDown.
        self.github_row = self.github.update_library_data(
            self.config.github_user, self.config.github_repos[0])

    def tearDown(self):
        # Tolerate a missing file so a failed export is reported as such,
        # not masked by an OSError from os.remove.
        self._remove_export()
        # Clean up the row setUp inserted, as TestGitHub does for its own.
        if self.github_row is not None:
            self.db.delete_data(self.github_row.id, 'github_data')

    def _remove_export(self):
        """Delete the exported CSV if it exists."""
        if os.path.exists(self.filename):
            os.remove(self.filename)

    def test_file_export_succeeds(self):
        self.assertFalse(os.path.exists(self.filename))
        self.db.export_table_to_csv(GitHubData)
        self.assertTrue(os.path.exists(self.filename))

    def test_file_export_has_correct_data(self):
        self.db.export_table_to_csv(GitHubData)
        with open(self.filename, 'r') as fp:
            exported_data = fp.readlines()

        # Table has correct header
        self.assertEqual(exported_data[0], self.header_row)

        # Table exported correct number of rows
        num_exported_rows = len(exported_data) - 1  # exclude header
        num_db_rows = len(self.db.get_data(GitHubData))
        self.assertEqual(num_exported_rows, num_db_rows)


class TestLicenseYear(unittest.TestCase):
    """LICENSE.txt's copyright line must end at the current year."""

    def setUp(self):
        self.license_file = 'LICENSE.txt'

    def test_license_year(self):
        # First line starting with 'Copyright', stripped; '' if none exists.
        with open(self.license_file, 'r') as f:
            copyright_line = next(
                (line.strip() for line in f if line.startswith('Copyright')),
                '')
        self.assertEqual('Copyright (c) 2016-%s SendGrid, Inc.'
                         % datetime.datetime.now().year,
                         copyright_line)


if __name__ == '__main__':
    unittest.main()