from django.db import models

from psqlextra.fields import HStoreField

from .fake_model import get_fake_model


def _assert_sequential_models(model, objs, expected_count):
    """Assert that ``objs`` holds ``expected_count`` instances of ``model``
    whose ids count up from 1.

    The number of objects is checked explicitly: iterating an empty result
    would otherwise execute no assertions at all and let the calling test
    pass without verifying anything.
    """
    objs = list(objs)
    assert len(objs) == expected_count

    for index, obj in enumerate(objs, 1):
        assert isinstance(obj, model)
        assert obj.id == index


def test_upsert():
    """Tests whether simple upserts work correctly."""

    model = get_fake_model(
        {
            "title": HStoreField(uniqueness=["key1"]),
            "cookies": models.CharField(max_length=255, null=True),
        }
    )

    obj1 = model.objects.upsert_and_get(
        conflict_target=[("title", "key1")],
        fields=dict(title={"key1": "beer"}, cookies="cheers"),
    )

    obj1.refresh_from_db()
    assert obj1.title["key1"] == "beer"
    assert obj1.cookies == "cheers"

    # second upsert uses the same hstore key value, only `cookies` differs
    obj2 = model.objects.upsert_and_get(
        conflict_target=[("title", "key1")],
        fields=dict(title={"key1": "beer"}, cookies="choco"),
    )

    obj1.refresh_from_db()
    obj2.refresh_from_db()

    # assert both objects are the same
    assert obj1.id == obj2.id
    assert obj1.title["key1"] == "beer"
    assert obj1.cookies == "choco"
    assert obj2.title["key1"] == "beer"
    assert obj2.cookies == "choco"


def test_upsert_explicit_pk():
    """Tests whether upserts work when the primary key is explicitly
    specified."""

    model = get_fake_model(
        {
            "name": models.CharField(max_length=255, primary_key=True),
            "cookies": models.CharField(max_length=255, null=True),
        }
    )

    # the conflict target is a plain field name; the original spelling
    # `("name")` was just a parenthesised string, not a tuple
    obj1 = model.objects.upsert_and_get(
        conflict_target=["name"],
        fields=dict(name="the-object", cookies="first-cheers"),
    )

    obj1.refresh_from_db()
    assert obj1.name == "the-object"
    assert obj1.cookies == "first-cheers"

    obj2 = model.objects.upsert_and_get(
        conflict_target=["name"],
        fields=dict(name="the-object", cookies="second-boo"),
    )

    obj1.refresh_from_db()
    obj2.refresh_from_db()

    # assert both objects are the same
    assert obj1.pk == obj2.pk
    assert obj1.name == "the-object"
    assert obj1.cookies == "second-boo"
    assert obj2.name == "the-object"
    assert obj2.cookies == "second-boo"


def test_upsert_bulk():
    """Tests whether bulk_upsert works properly."""

    model = get_fake_model(
        {
            "first_name": models.CharField(
                max_length=255, null=True, unique=True
            ),
            "last_name": models.CharField(max_length=255, null=True),
        }
    )

    model.objects.bulk_upsert(
        conflict_target=["first_name"],
        rows=[
            dict(first_name="Swen", last_name="Kooij"),
            dict(first_name="Henk", last_name="Test"),
        ],
    )

    row_a = model.objects.get(first_name="Swen")
    row_b = model.objects.get(first_name="Henk")

    # same first names again, with the last names swapped
    model.objects.bulk_upsert(
        conflict_target=["first_name"],
        rows=[
            dict(first_name="Swen", last_name="Test"),
            dict(first_name="Henk", last_name="Kooij"),
        ],
    )

    row_a.refresh_from_db()
    assert row_a.last_name == "Test"

    row_b.refresh_from_db()
    assert row_b.last_name == "Kooij"


def test_upsert_bulk_no_rows():
    """Tests whether bulk_upsert doesn't crash when specifying no rows or a
    falsy value."""

    model = get_fake_model(
        {"name": models.CharField(max_length=255, null=True, unique=True)}
    )

    model.objects.bulk_upsert(conflict_target=["name"], rows=[])
    model.objects.bulk_upsert(conflict_target=["name"], rows=None)


def test_bulk_upsert_return_models():
    """Tests whether models are returned instead of dictionaries when
    specifying the return_model=True argument."""

    model = get_fake_model(
        {
            "id": models.BigAutoField(primary_key=True),
            "name": models.CharField(max_length=255, unique=True),
        }
    )

    rows = [dict(name="John Smith"), dict(name="Jane Doe")]

    objs = model.objects.bulk_upsert(
        conflict_target=["name"], rows=rows, return_model=True
    )

    _assert_sequential_models(model, objs, expected_count=2)


def test_bulk_upsert_accepts_getitem_iterable():
    """Tests whether an iterable only implementing the __getitem__ method
    works correctly."""

    class GetItemIterable:
        def __init__(self, items):
            self.items = items

        def __getitem__(self, key):
            return self.items[key]

    model = get_fake_model(
        {
            "id": models.BigAutoField(primary_key=True),
            "name": models.CharField(max_length=255, unique=True),
        }
    )

    rows = GetItemIterable([dict(name="John Smith"), dict(name="Jane Doe")])

    objs = model.objects.bulk_upsert(
        conflict_target=["name"], rows=rows, return_model=True
    )

    _assert_sequential_models(model, objs, expected_count=2)


def test_bulk_upsert_accepts_iter_iterable():
    """Tests whether an iterable only implementing the __iter__ method works
    correctly."""

    class IterIterable:
        def __init__(self, items):
            self.items = items

        def __iter__(self):
            return iter(self.items)

    model = get_fake_model(
        {
            "id": models.BigAutoField(primary_key=True),
            "name": models.CharField(max_length=255, unique=True),
        }
    )

    rows = IterIterable([dict(name="John Smith"), dict(name="Jane Doe")])

    objs = model.objects.bulk_upsert(
        conflict_target=["name"], rows=rows, return_model=True
    )

    _assert_sequential_models(model, objs, expected_count=2)