"""Tests for the SQLAlchemy storage backend (``SQLAStorage``).

The SQLite test cases always run.  The MySQL and PostgreSQL variants reuse
the same test methods against a local server and are skipped when the
corresponding driver package cannot be imported.
"""
try:
    from builtins import range
except ImportError:
    pass
import unittest
import tempfile
import os
from flask_blogging.sqlastorage import SQLAStorage
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from test import FlaskBloggingTestCase
import sqlalchemy as sqla
from flask_sqlalchemy import SQLAlchemy
import time

# Optional database drivers: only their importability matters here, it
# decides whether the server-backed test cases below are skipped.
try:
    import _mysql
    HAS_MYSQL = True
except ImportError:
    HAS_MYSQL = False

try:
    import psycopg2
    HAS_POSTGRES = True
except ImportError:
    HAS_POSTGRES = False


def _assert_table_columns(case, table_name, expected_columns):
    """Assert that ``table_name`` exists and has exactly the given columns.

    :param case: test case exposing ``_engine`` (SQLAlchemy engine) and
        ``_meta`` (the metadata the storage registered its tables on).
    :param table_name: name of the table to check.
    :param expected_columns: column names, in declaration order.
    """
    # Existence is checked against the live database ...
    with case._engine.begin() as conn:
        case.assertTrue(conn.dialect.has_table(conn, table_name))
    # ... while the column layout is read from the in-memory metadata.
    table = case._meta.tables[table_name]
    columns = [column.name for column in table.columns]
    case.assertListEqual(columns, expected_columns)


class TestSQLiteStorage(FlaskBloggingTestCase):
    """Exercise ``SQLAStorage`` on a file-based SQLite database.

    Subclasses swap the database by overriding ``_create_storage`` and
    ``tearDown``; every test method is inherited unchanged.
    """

    def _create_storage(self):
        """Create the engine, metadata, storage and all tables."""
        temp_dir = tempfile.gettempdir()
        self._dbfile = os.path.join(temp_dir, "temp.db")
        # A previous run that failed during setUp never reached tearDown and
        # leaves its database behind; start from a clean file.
        if os.path.exists(self._dbfile):
            os.remove(self._dbfile)
        self._engine = create_engine('sqlite:///'+self._dbfile)
        self._meta = sqla.MetaData()
        self.storage = SQLAStorage(self._engine, metadata=self._meta)
        self._meta.create_all(bind=self._engine)

    def setUp(self):
        FlaskBloggingTestCase.setUp(self)
        self._create_storage()

    def tearDown(self):
        # Release pooled connections before deleting the file; an open
        # handle prevents removal on platforms that lock open files.
        self._engine.dispose()
        os.remove(self._dbfile)

    def test_post_table_exists(self):
        """The ``post`` table exists with the expected columns."""
        _assert_table_columns(
            self, "post",
            ['id', 'title', 'text', 'post_date', 'last_modified_date',
             'draft'])

    def test_tag_table_exists(self):
        """The ``tag`` table exists with the expected columns."""
        _assert_table_columns(self, "tag", ['id', 'text'])

    def test_tag_post_table_exists(self):
        """The ``tag_posts`` table exists with the expected columns."""
        _assert_table_columns(self, "tag_posts", ['tag_id', 'post_id'])

    def test_user_post_table_exists(self):
        """The ``user_posts`` table exists with the expected columns."""
        _assert_table_columns(self, "user_posts", ['user_id', 'post_id'])

    def test_user_post_table_consistency(self):
        """Re-saving a post under another user must not duplicate it."""
        pid = self.storage.save_post(title="Title", text="Sample Text",
                                     user_id="user", tags=["hello", "world"])
        posts = self.storage.get_posts()
        self.assertEqual(len(posts), 1)
        self.storage.save_post(title="Title", text="Sample Text",
                               user_id="newuser", tags=["hello", "world"],
                               post_id=pid)
        # Query again: the list fetched above is a snapshot taken before the
        # update and would make this assertion vacuous.
        posts = self.storage.get_posts()
        self.assertEqual(len(posts), 1)

    def _test_user_post_model_consistency(self):
        """DISABLED (leading underscore): check the ``Post`` ORM model.

        Saves one post through the storage and reads it back through a
        plain SQLAlchemy session bound to the same engine.
        """
        pid = self.storage.save_post(title="Title", text="Sample Text",
                                     user_id="user", tags=["hello", "world"])
        from flask_blogging.sqlastorage import Post
        Session = sessionmaker()
        Session.configure(bind=self._engine)
        session = Session()
        try:
            posts = session.query(Post).all()
            self.assertEqual(len(posts), 1)
            post = posts[0]
            self.assertEqual(post.title, "Title")
            self.assertEqual(post.text, "Sample Text")
            self.assertEqual(post.id, pid)
        finally:
            session.close()

    def test_tags_uniqueness(self):
        """Inserting the same tag text twice raises ``IntegrityError``."""
        table = self._meta.tables["tag"]
        with self._engine.begin() as conn:
            statement = table.insert().values(text="test_tag")
            conn.execute(statement)
        # reentering same tag should raise exception
        with self._engine.begin() as conn:
            statement = table.insert().values(text="test_tag")
            self.assertRaises(sqla.exc.IntegrityError,
                              conn.execute, statement)

    def test_tags_consistency(self):
        """Updating a post with fewer tags shrinks its tag list."""
        tags = ["hello", "world"]
        pid = self.storage.save_post(title="Title", text="Sample Text",
                                     user_id="user", tags=tags)
        post = self.storage.get_post_by_id(pid)
        self.assertEqual(len(post["tags"]), 2)
        tags.pop()
        pid = self.storage.save_post(title="Title", text="Sample Text",
                                     user_id="user", tags=tags, post_id=pid)
        post = self.storage.get_post_by_id(pid)
        self.assertEqual(len(post["tags"]), 1)

    def test_tag_post_uniqueness(self):
        """A duplicate (tag_id, post_id) row raises ``IntegrityError``."""
        self.storage.save_post(title="Title", text="Sample Text",
                               user_id="user", tags=["tags"])
        table = self._meta.tables["tag_posts"]
        with self._engine.begin() as conn:
            # The first saved post and its single tag both get id 1, so this
            # row repeats the link save_post already created.
            statement = table.insert().values(tag_id=1, post_id=1)
            self.assertRaises(sqla.exc.IntegrityError,
                              conn.execute, statement)

    def test_user_post_uniqueness(self):
        """A duplicate (user_id, post_id) row raises ``IntegrityError``."""
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"])
        # Reflect from the database instead of reusing self._meta so the
        # constraint is checked against what was actually created.
        metadata = sqla.MetaData()
        metadata.reflect(bind=self._engine)
        table = metadata.tables["user_posts"]
        # reentering same user should raise exception
        with self._engine.begin() as conn:
            statement = table.insert().values(user_id="testuser",
                                              post_id=pid)
            self.assertRaises(sqla.exc.IntegrityError,
                              conn.execute, statement)

    def test_bind_database(self):
        """All four storage tables exist after the storage is created."""
        # self.storage._create_all_tables()
        self.test_post_table_exists()
        self.test_tag_table_exists()
        self.test_tag_post_table_exists()
        self.test_user_post_table_exists()

    def test_save_post(self):
        """Known ids update in place; unknown ids are treated as inserts."""
        self.storage.save_post(title="Title1", text="Sample Text",
                               user_id="testuser", tags=["hello", "world"])
        # Saving with the existing id 1 must update, not create post 2.
        self.storage.save_post(title="Title1", text="Sample Text",
                               user_id="testuser", tags=["hello", "world"],
                               post_id=1)
        p = self.storage.get_post_by_id(2)
        self.assertIsNone(p)

        # invalid post_id will be treated as inserts
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"], post_id=5)
        self.assertNotEqual(pid, 5)
        self.assertEqual(pid, 2)
        p = self.storage.get_post_by_id(2)
        self.assertIsNotNone(p)

    def test_delete_post(self):
        """A deleted post is gone and a post can be saved again after."""
        # insert, check exists, delete, check doesn't exist anymore
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"])
        p = self.storage.get_post_by_id(pid)
        self.assertIsNotNone(p)
        self.storage.delete_post(pid)
        p = self.storage.get_post_by_id(pid)
        self.assertIsNone(p)

        # insert again.
        pid = self.storage.save_post(title="Title1", text="Sample Text",
                                     user_id="testuser",
                                     tags=["hello", "world"], post_id=1)
        p = self.storage.get_post_by_id(pid)
        self.assertIsNotNone(p)

    def test_get_post_by_id(self):
        """Posts come back with the title, text, user and tags saved."""
        pid1 = self.storage.save_post(title="Title1", text="Sample Text1",
                                      user_id="testuser",
                                      tags=["hello", "world"])
        pid2 = self.storage.save_post(title="Title2", text="Sample Text2",
                                      user_id="testuser",
                                      tags=["hello", "my", "world"])
        post = self.storage.get_post_by_id(pid1)
        self._assert_post(post, "Title1", "Sample Text1", "testuser",
                          ["HELLO", "WORLD"])

        post = self.storage.get_post_by_id(pid2)
        self._assert_post(post, "Title2", "Sample Text2", "testuser",
                          ["HELLO", "MY", "WORLD"])

    def _assert_post(self, post, title, text, user_id, tags):
        """Assert the fields of a post dict returned by the storage.

        ``tags`` is upper-cased before the comparison, so callers may pass
        tags in any case; order is ignored.
        """
        expected_tags = {tag.upper() for tag in tags}
        self.assertSetEqual(set(post["tags"]), expected_tags)
        self.assertEqual(post["title"], title)
        self.assertEqual(post["text"], text)
        self.assertEqual(post["user_id"], user_id)

    def _assert_numbered_posts(self, posts, first, step, user_id, tags):
        """Assert ``posts`` are consecutive dummy posts from ``first``.

        The post at position ``k`` must be the dummy post numbered
        ``first + step * k`` as written by ``_create_dummy_data``.
        """
        for position, post in enumerate(posts):
            number = first + step * position
            self._assert_post(post, "Title%d" % number,
                              "Sample Text%d" % number, user_id, tags)

    def test_get_posts(self):
        """``get_posts`` honours ordering, paging, tag and user filters."""
        self._create_dummy_data()

        # test default queries
        posts = self.storage.get_posts()
        self.assertEqual(len(posts), 10)
        self._assert_numbered_posts(posts, 19, -1, "newuser", ["world"])

        posts = self.storage.get_posts(recent=False)
        self.assertEqual(len(posts), 10)
        self._assert_numbered_posts(posts, 0, 1, "testuser", ["hello"])

        # test count and offset
        posts = self.storage.get_posts(count=5, offset=5, recent=False)
        self.assertEqual(len(posts), 5)
        self._assert_numbered_posts(posts, 5, 1, "testuser", ["hello"])

        # test tag feature
        posts = self.storage.get_posts(tag="hello", recent=False)
        self.assertEqual(len(posts), 10)
        self._assert_numbered_posts(posts, 0, 1, "testuser", ["hello"])

        posts = self.storage.get_posts(tag="world", recent=False)
        self.assertEqual(len(posts), 10)
        self._assert_numbered_posts(posts, 10, 1, "newuser", ["world"])

        # test user_id feature
        posts = self.storage.get_posts(user_id="newuser", recent=True)
        self.assertEqual(len(posts), 10)
        self._assert_numbered_posts(posts, 19, -1, "newuser", ["world"])

        posts = self.storage.get_posts(user_id="testuser", recent=True)
        self.assertEqual(len(posts), 10)
        self._assert_numbered_posts(posts, 9, -1, "testuser", ["hello"])

    def test_count_posts(self):
        """``count_posts`` honours the user and tag filters."""
        self._create_dummy_data()

        count = self.storage.count_posts()
        self.assertEqual(count, 20)

        # test user
        count = self.storage.count_posts(user_id="testuser")
        self.assertEqual(count, 10)
        count = self.storage.count_posts(user_id="newuser")
        self.assertEqual(count, 10)

        # test tags
        count = self.storage.count_posts(tag="hello")
        self.assertEqual(count, 10)
        count = self.storage.count_posts(tag="world")
        self.assertEqual(count, 10)

        # multiple queries
        count = self.storage.count_posts(user_id="testuser", tag="world")
        self.assertEqual(count, 0)

    def _create_dummy_data(self):
        """Save 20 numbered posts.

        Posts 0-9 belong to ``testuser`` and are tagged ``hello``; posts
        10-19 belong to ``newuser`` and are tagged ``world``.
        """
        for i in range(20):
            tags = ["hello"] if i < 10 else ["world"]
            user = "testuser" if i < 10 else "newuser"
            self.storage.save_post(title="Title%d" % i,
                                   text="Sample Text%d" % i,
                                   user_id=user, tags=tags)
            # Space the saves out so every post gets a distinct timestamp;
            # the ordering assertions presumably rely on the post date and
            # some backends only store whole seconds - TODO confirm.
            time.sleep(1)


@unittest.skipUnless(HAS_MYSQL, "Package mysql-python needs to be install to "
                                "run this test.")
class TestMySQLStorage(TestSQLiteStorage):
    """Run the storage tests against a local MySQL server."""

    def _create_storage(self):
        """Create the engine, metadata, storage and all tables."""
        self._engine = create_engine(
            "mysql+mysqldb://root:@localhost/flask_blogging")
        self._meta = sqla.MetaData()
        self.storage = SQLAStorage(self._engine, metadata=self._meta)
        self._meta.create_all(bind=self._engine)

    def tearDown(self):
        # Drop whatever is actually in the database (reflected), not only
        # the tables known to self._meta.
        metadata = sqla.MetaData()
        metadata.reflect(bind=self._engine)
        metadata.drop_all(bind=self._engine)


@unittest.skipUnless(HAS_POSTGRES, "Requires psycopg2 Postgres package")
class TestPostgresStorage(TestSQLiteStorage):
    """Run the storage tests against a local PostgreSQL server."""

    def _create_storage(self):
        """Create the engine, metadata, storage and all tables."""
        self._engine = create_engine(
            "postgresql+psycopg2://postgres:@localhost/flask_blogging")
        self._meta = sqla.MetaData()
        self.storage = SQLAStorage(self._engine, metadata=self._meta)
        self._meta.create_all(bind=self._engine)

    def tearDown(self):
        # Drop whatever is actually in the database (reflected), not only
        # the tables known to self._meta.
        metadata = sqla.MetaData()
        metadata.reflect(bind=self._engine)
        metadata.drop_all(bind=self._engine)

    def test_user_post_model_consistency(self):
        # Intentionally empty.  NOTE(review): the base class only defines
        # the disabled ``_test_user_post_model_consistency``, so this
        # overrides nothing and merely adds a trivially passing test.
        pass


class TestSQLiteBinds(FlaskBloggingTestCase):
    """Check table creation when the storage uses a Flask-SQLAlchemy bind.

    Subclasses point the ``blog`` bind at another database by overriding
    ``_conn_string``.
    """

    def _conn_string(self, dbfile):
        """Return the connection URL for the ``blog`` bind."""
        return 'sqlite:///'+dbfile

    def setUp(self):
        FlaskBloggingTestCase.setUp(self)
        temp_dir = tempfile.gettempdir()
        self._dbfile = os.path.join(temp_dir, "temp.db")
        # A previous run that failed during setUp never reached tearDown and
        # leaves its database behind; start from a clean file.
        if os.path.exists(self._dbfile):
            os.remove(self._dbfile)
        conn_string = self._conn_string(self._dbfile)
        self.app.config["SQLALCHEMY_BINDS"] = {
            'blog': conn_string
        }
        self._db = SQLAlchemy(self.app)
        self.storage = SQLAStorage(db=self._db, bind="blog")
        self._engine = self._db.get_engine(self.app, bind="blog")
        self._meta = self._db.metadata
        self._db.create_all(bind=["blog"])

    def tearDown(self):
        # Release pooled connections before deleting the file; an open
        # handle prevents removal on platforms that lock open files.
        self._engine.dispose()
        os.remove(self._dbfile)

    def test_post_table_exists(self):
        """The ``post`` table and the ``Post`` model both exist."""
        _assert_table_columns(
            self, "post",
            ['id', 'title', 'text', 'post_date', 'last_modified_date',
             'draft'])
        # test models
        # Imported here rather than at module level so the name is read
        # after the storage has been constructed in setUp.
        from flask_blogging.sqlastorage import Post
        self.assertIsNotNone(Post)

    def test_tag_table_exists(self):
        """The ``tag`` table and the ``Tag`` model both exist."""
        _assert_table_columns(self, "tag", ['id', 'text'])
        # test models
        # Imported here rather than at module level so the name is read
        # after the storage has been constructed in setUp.
        from flask_blogging.sqlastorage import Tag
        self.assertIsNotNone(Tag)

    def test_tag_post_table_exists(self):
        """The ``tag_posts`` table exists with the expected columns."""
        _assert_table_columns(self, "tag_posts", ['tag_id', 'post_id'])

    def test_user_post_table_exists(self):
        """The ``user_posts`` table exists with the expected columns."""
        _assert_table_columns(self, "user_posts", ['user_id', 'post_id'])


@unittest.skipUnless(HAS_MYSQL, "Package mysql-python needs to be install to "
                                "run this test.")
class TestMySQLBinds(TestSQLiteBinds):
    """Run the bind tests against a local MySQL server."""

    def _conn_string(self, dbfile):
        """Return the MySQL URL; ``dbfile`` is ignored."""
        return "mysql+mysqldb://root:@localhost/flask_blogging"

    def tearDown(self):
        # No database file to remove for a server-backed bind.
        pass


@unittest.skipUnless(HAS_POSTGRES, "Requires psycopg2 Postgres package")
class TestPostgresBinds(TestSQLiteBinds):
    """Run the bind tests against a local PostgreSQL server."""

    def _conn_string(self, dbfile):
        """Return the PostgreSQL URL; ``dbfile`` is ignored."""
        return "postgresql+psycopg2://postgres:@localhost/flask_blogging"

    def tearDown(self):
        # No database file to remove for a server-backed bind.
        pass