"""
Tests for the ``bulk_update`` manager method and ``helper.get_fields`` from
django_bulk_update.
"""
import random
from datetime import date, time, timedelta
from decimal import Decimal
from unittest import skipUnless

from django.conf import settings
from django.db.models import F, Func, Value
from django.db.models.functions import Concat
from django.test import TestCase
from django.utils import timezone

from django_bulk_update import helper
from .models import Person, Role, PersonUUID, Brand
from .fixtures import create_fixtures


class BulkUpdateTests(TestCase):
    """
    End-to-end tests: modify objects in memory, call ``bulk_update`` and
    reload them from the database to check what was actually persisted.
    """

    def setUp(self):
        # mysql doesn't do microseconds.
        self.now = timezone.now().replace(microsecond=0)
        self.date = date(2015, 3, 28)
        self.time = time(13, 0)
        create_fixtures()

    def _test_field(self, field, idx_to_value_function):
        """
        Helper to do repetitive simple tests on one field.

        Sets ``field`` on every ``Person`` (ordered by pk) to
        ``idx_to_value_function(idx)``, bulk updates only that field, then
        reloads the objects and checks every saved value.
        """
        # set
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            value = idx_to_value_function(idx)
            setattr(person, field, value)

        # update
        Person.objects.bulk_update(people, update_fields=[field])

        # check
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            saved_value = getattr(person, field)
            expected_value = idx_to_value_function(idx)
            self.assertEqual(saved_value, expected_value)

    def test_simple_fields(self):
        fn = lambda idx: idx + 27
        for field in ('default', 'big_age', 'age', 'positive_age',
                      'positive_small_age', 'small_age'):
            self._test_field(field, fn)

    def test_boolean_field(self):
        fn = lambda idx: [True, False][idx % 2]
        self._test_field('certified', fn)

    def test_null_boolean_field(self):
        fn = lambda idx: [True, False, None][idx % 3]
        self._test_field('null_certified', fn)

    # The cyclic-value tests below index with ``% len(...)`` so that every
    # listed value is exercised, not only the first five.
    def test_char_field(self):
        NAMES = ['Walter', 'The Dude', 'Donny', 'Jesus', 'Buddha', 'Clark']
        fn = lambda idx: NAMES[idx % len(NAMES)]
        self._test_field('name', fn)

    def test_email_field(self):
        EMAILS = ['walter@mailinator.com', 'thedude@mailinator.com',
                  'donny@mailinator.com', 'jesus@mailinator.com',
                  'buddha@mailinator.com', 'clark@mailinator.com']
        fn = lambda idx: EMAILS[idx % len(EMAILS)]
        self._test_field('email', fn)

    def test_file_path_field(self):
        PATHS = ['/home/dummy.txt', '/Downloads/kitten.jpg',
                 '/Users/user/fixtures.json', 'dummy.png', 'users.json',
                 '/home/dummy.png']
        fn = lambda idx: PATHS[idx % len(PATHS)]
        self._test_field('file_path', fn)

    def test_slug_field(self):
        SLUGS = ['jesus', 'buddha', 'clark', 'the-dude', 'donny', 'walter']
        fn = lambda idx: SLUGS[idx % len(SLUGS)]
        self._test_field('slug', fn)

    def test_text_field(self):
        TEXTS = ['this is a dummy text', 'dummy text', 'bla bla bla bla bla',
                 'here is a dummy text', 'dummy', 'bla bla bla']
        fn = lambda idx: TEXTS[idx % len(TEXTS)]
        self._test_field('text', fn)

    def test_url_field(self):
        URLS = ['docs.djangoproject.com', 'news.ycombinator.com',
                'https://docs.djangoproject.com', 'https://google.com',
                'google.com', 'news.ycombinator.com']
        fn = lambda idx: URLS[idx % len(URLS)]
        self._test_field('url', fn)

    def test_date_time_field(self):
        fn = lambda idx: self.now - timedelta(days=1 + idx, hours=1 + idx)
        self._test_field('date_time', fn)

    def test_date_field(self):
        fn = lambda idx: self.date - timedelta(days=1 + idx)
        self._test_field('date', fn)

    def test_time_field(self):
        fn = lambda idx: time(1 + idx, idx)
        self._test_field('time', fn)

    def test_decimal_field(self):
        fn = lambda idx: Decimal('1.%s' % (50 + idx * 7))
        self._test_field('height', fn)

    def test_float_field(self):
        fn = lambda idx: float(idx) * 2.0
        self._test_field('float_height', fn)

    def test_data_field(self):
        fn = lambda idx: {'x': idx}
        self._test_field('data', fn)

    def test_generic_ipaddress_field(self):
        IPS = ['127.0.0.1', '192.0.2.30', '2a02:42fe::4', '10.0.0.1',
               '8.8.8.8']
        fn = lambda idx: IPS[idx % len(IPS)]
        self._test_field('remote_addr', fn)

    def test_image_field(self):
        IMGS = ['kitten.jpg', 'dummy.png', 'user.json', 'dummy.png', 'foo.gif']
        fn = lambda idx: IMGS[idx % len(IMGS)]
        self._test_field('image', fn)
        self._test_field('my_file', fn)

    def test_custom_fields(self):
        values = {}
        people = Person.objects.all()
        people_dict = {p.name: p for p in people}

        person = people_dict['Mike']
        person.data = {'name': 'mikey', 'age': 99, 'ex': -99}
        values[person.pk] = {'name': 'mikey', 'age': 99, 'ex': -99}

        person = people_dict['Mary']
        person.data = {'names': {'name': []}}
        values[person.pk] = {'names': {'name': []}}

        person = people_dict['Pete']
        person.data = []
        values[person.pk] = []

        person = people_dict['Sandra']
        person.data = [{'name': 'Pete'}, {'name': 'Mike'}]
        values[person.pk] = [{'name': 'Pete'}, {'name': 'Mike'}]

        person = people_dict['Ash']
        person.data = {'text': 'bla'}
        values[person.pk] = {'text': 'bla'}

        # Crystal is left untouched: its stored value must survive as-is.
        person = people_dict['Crystal']
        values[person.pk] = person.data

        Person.objects.bulk_update(people)

        people = Person.objects.all()
        for person in people:
            self.assertEqual(person.data, values[person.pk])

    def test_update_fields(self):
        """
        Only the fields in "update_fields" are updated
        """
        people = Person.objects.order_by('pk').all()
        for person in people:
            person.age += 1
            person.height += Decimal('0.01')
        Person.objects.bulk_update(people, update_fields=['age'])

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertEqual(person1.age, person2.age)
            self.assertNotEqual(person1.height, person2.height)

    def test_update_foreign_key_fields(self):
        roles = [Role.objects.create(code=1), Role.objects.create(code=2)]
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.age += 1
            person.height += Decimal('0.01')
            person.role = roles[0] if idx % 2 == 0 else roles[1]
        Person.objects.bulk_update(people)

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertEqual(person1.role.code, person2.role.code)
            self.assertEqual(person1.age, person2.age)
            self.assertEqual(person1.height, person2.height)

    def test_update_foreign_key_fields_explicit(self):
        roles = [Role.objects.create(code=1), Role.objects.create(code=2)]
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.age += 1
            person.height += Decimal('0.01')
            person.role = roles[0] if idx % 2 == 0 else roles[1]
            person.big_age += 40
        Person.objects.bulk_update(people,
                                   update_fields=['age', 'height', 'role'])

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertEqual(person1.role.code, person2.role.code)
            self.assertEqual(person1.age, person2.age)
            self.assertEqual(person1.height, person2.height)
            self.assertNotEqual(person1.big_age, person2.big_age)

    def test_update_foreign_key_fields_explicit_with_id_suffix(self):
        roles = [Role.objects.create(code=1), Role.objects.create(code=2)]
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.age += 1
            person.height += Decimal('0.01')
            person.role = roles[0] if idx % 2 == 0 else roles[1]
        Person.objects.bulk_update(people,
                                   update_fields=['age', 'height', 'role_id'])

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertEqual(person1.role.code, person2.role.code)
            self.assertEqual(person1.age, person2.age)
            self.assertEqual(person1.height, person2.height)

    def test_update_foreign_key_exclude_fields_explicit(self):
        roles = [Role.objects.create(code=1), Role.objects.create(code=2)]
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.age += 1
            person.height += Decimal('0.01')
            person.role = roles[0] if idx % 2 == 0 else roles[1]
            person.big_age += 40
        Person.objects.bulk_update(people,
                                   update_fields=['age', 'height'],
                                   exclude_fields=['role'])

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertIsInstance(person1.role, Role)
            self.assertIsNone(person2.role)
            self.assertEqual(person1.age, person2.age)
            self.assertEqual(person1.height, person2.height)
            self.assertNotEqual(person1.big_age, person2.big_age)

    def test_update_foreign_key_exclude_fields_explicit_with_id_suffix(self):
        roles = [Role.objects.create(code=1), Role.objects.create(code=2)]
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.age += 1
            person.height += Decimal('0.01')
            person.role = roles[0] if idx % 2 == 0 else roles[1]
        Person.objects.bulk_update(people,
                                   update_fields=['age', 'height'],
                                   exclude_fields=['role_id'])

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertIsInstance(person1.role, Role)
            self.assertIsNone(person2.role)
            self.assertEqual(person1.age, person2.age)
            self.assertEqual(person1.height, person2.height)

    def test_exclude_fields(self):
        """
        Only the fields not in "exclude_fields" are updated
        """
        people = Person.objects.order_by('pk').all()
        for person in people:
            person.age += 1
            person.height += Decimal('0.01')
        Person.objects.bulk_update(people, exclude_fields=['age'])

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertNotEqual(person1.age, person2.age)
            self.assertEqual(person1.height, person2.height)

    def test_exclude_fields_with_tuple_exclude_fields(self):
        """
        Only the fields not in "exclude_fields" are updated
        """
        people = Person.objects.order_by('pk').all()
        for person in people:
            person.age += 1
            person.height += Decimal('0.01')
        Person.objects.bulk_update(people, exclude_fields=('age',))

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertNotEqual(person1.age, person2.age)
            self.assertEqual(person1.height, person2.height)

    def test_object_list(self):
        """
        Pass in a list instead of a queryset for bulk updating
        """
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.big_age = idx + 27
        Person.objects.bulk_update(list(people))

        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            self.assertEqual(person.big_age, idx + 27)

    def test_empty_list(self):
        """
        Update no elements, passed as a list
        """
        Person.objects.bulk_update([])

    def test_empty_queryset(self):
        """
        Update no elements, passed as a queryset
        """
        people = Person.objects.filter(name="Aceldotanrilsteucsebces ECSbd")
        Person.objects.bulk_update(people)

    def test_one_sized_list(self):
        """
        Update a one-element list; checks for a syntax error on some db
        backends.
        """
        people = Person.objects.all()[:1]
        Person.objects.bulk_update(list(people))

    def test_one_sized_queryset(self):
        """
        Update a one-element queryset; checks for a syntax error on some db
        backends.
        """
        people = Person.objects.filter(name='Mike')
        Person.objects.bulk_update(people)

    def test_wrong_field_names(self):
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.big_age = idx + 27
        self.assertRaises(TypeError, Person.objects.bulk_update, people,
                          update_fields=['somecolumn', 'name'])

        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.big_age = idx + 27
        self.assertRaises(TypeError, Person.objects.bulk_update, people,
                          exclude_fields=['somecolumn'])

        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.big_age = idx + 27
        self.assertRaises(TypeError, Person.objects.bulk_update, people,
                          update_fields=['somecolumn'],
                          exclude_fields=['someothercolumn'])

    def test_batch_size(self):
        people = Person.objects.order_by('pk').all()
        for person in people:
            person.age += 1
            person.height += Decimal('0.01')
        updated_obj_count = Person.objects.bulk_update(people, batch_size=1)
        self.assertEqual(updated_obj_count, len(people))

        people2 = Person.objects.order_by('pk').all()
        for person1, person2 in zip(people, people2):
            self.assertEqual(person1.age, person2.age)
            self.assertEqual(person1.height, person2.height)

    # ``.get`` so a database config without a USER key (e.g. sqlite) skips
    # the test instead of raising KeyError when this module is imported.
    @skipUnless(settings.DATABASES['default'].get('USER') == 'postgres',
                "ArrayField's are only available in PostgreSQL.")
    def test_array_field(self):
        """
        Test to 'bulk_update' a postgresql's ArrayField.
        """
        Brand.objects.bulk_create([
            Brand(name='b1', codes=['a', 'b']),
            Brand(name='b2', codes=['x']),
            Brand(name='b3', codes=['x', 'y', 'z']),
            Brand(name='b4', codes=['1', '2']),
        ])

        brands = Brand.objects.all()
        for brand in brands:
            brand.codes.append(brand.codes[0] * 2)
        Brand.objects.bulk_update(brands)

        # Reload from the database (in a deterministic order): asserting on
        # the in-memory ``brands`` would pass even if nothing was saved.
        brands = Brand.objects.filter(
            name__in=['b1', 'b2', 'b3', 'b4']).order_by('name')
        expected = ['aa', 'xx', 'xx', '11']
        self.assertEqual([brand.codes[-1] for brand in brands], expected)

    def test_uuid_pk(self):
        """
        Test 'bulk_update' with a model whose pk is an uuid.
        """
        # create
        PersonUUID.objects.bulk_create(
            [PersonUUID(age=c) for c in range(20, 30)])

        # set
        people = PersonUUID.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.age = idx * 11

        # update
        PersonUUID.objects.bulk_update(people, update_fields=['age'])

        # check
        people = PersonUUID.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            self.assertEqual(person.age, idx * 11)

    def test_F_expresion(self):
        # initialize
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.age = idx * 10
            person.save()

        # set
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            person.age = F('age') - idx

        # update
        Person.objects.bulk_update(people)

        # check
        people = Person.objects.order_by('pk').all()
        for idx, person in enumerate(people):
            self.assertEqual(person.age, idx * 10 - idx)

    def test_Func_expresion(self):
        # initialize
        ini_values = 'aA', 'BB', '', 'cc', '12'
        people = Person.objects.order_by('pk').all()
        for value, person in zip(ini_values, people):
            person.name = value
            person.text = value * 2
            person.save()

        # set
        people = Person.objects.order_by('pk').all()
        for person in people:
            person.name = Func(F('name'), function='UPPER')
            person.text = Func(F('text'), function='LOWER')

        # update
        Person.objects.bulk_update(people)

        # check
        people = Person.objects.order_by('pk').all()

        expected_values = 'AA', 'BB', '', 'CC', '12'
        for expected_value, person in zip(expected_values, people):
            self.assertEqual(person.name, expected_value)

        expected_values = 'aaaa', 'bbbb', '', 'cccc', '1212'
        for expected_value, person in zip(expected_values, people):
            self.assertEqual(person.text, expected_value)

    def test_Concat_expresion(self):
        # initialize
        ini_values_1 = 'a', 'b', 'c', 'd', 'e'
        ini_values_2 = 'v', 'w', 'x', 'y', 'z'
        people = Person.objects.order_by('pk').all()
        for value1, value2, person in zip(ini_values_1, ini_values_2, people):
            person.slug = value1
            person.name = value2
            person.save()

        # set
        people = Person.objects.order_by('pk').all()
        for person in people:
            person.text = Concat(F('slug'), Value('@'), F('name'), Value('|'))

        # update
        Person.objects.bulk_update(people)

        # check
        people = Person.objects.order_by('pk').all()
        expected_values = 'a@v|', 'b@w|', 'c@x|', 'd@y|', 'e@z|'
        for expected_value, person in zip(expected_values, people):
            self.assertEqual(person.text, expected_value)

    def test_different_deferred_fields(self):
        # initialize
        people = Person.objects.order_by('pk').all()
        for person in people:
            person.name = 'original name'
            person.text = 'original text'
            person.save()

        # set: the two groups have different deferred fields
        people1 = list(Person.objects.filter(age__lt=10).only('name'))
        people2 = list(Person.objects.filter(age__gte=10).only('text'))
        people = people1 + people2
        for person in people:
            if person.age < 10:
                person.name = 'changed name'
            else:
                person.text = 'changed text'

        # update
        count = Person.objects.bulk_update(people)

        # check
        people = Person.objects.order_by('pk').all()
        self.assertEqual(count, people.count())
        for person in people:
            if person.age < 10:
                self.assertEqual(person.name, 'changed name')
                self.assertEqual(person.text, 'original text')
            else:
                self.assertEqual(person.name, 'original name')
                self.assertEqual(person.text, 'changed text')

    def test_different_deferred_fields_02(self):
        # initialize
        people = Person.objects.order_by('pk').all()
        for person in people:
            person.name = 'original name'
            person.text = 'original text'
            person.save()

        # set: the two groups have different deferred fields
        people1 = list(Person.objects.filter(age__lt=10).only('name'))
        people2 = list(Person.objects.filter(age__gte=10).only('text'))
        people = people1 + people2
        for person in people:
            if person.age < 10:
                person.name = 'changed name'
            else:
                person.text = 'changed text'

        # update
        count = Person.objects.bulk_update(people, exclude_fields=['name'])

        # check
        people = Person.objects.order_by('pk').all()
        self.assertEqual(count, people.count())
        for person in people:
            if person.age < 10:
                self.assertEqual(person.name, 'original name')
                self.assertEqual(person.text, 'original text')
            else:
                self.assertEqual(person.name, 'original name')
                self.assertEqual(person.text, 'changed text')


class NumQueriesTest(TestCase):
    """
    Check how many SQL queries ``bulk_update`` issues in various situations.
    """

    def setUp(self):
        create_fixtures(5)

    def test_num_queries(self):
        """
        Queries:
            - retrieve objects
            - update objects
        """
        people = Person.objects.order_by('pk').all()
        self.assertNumQueries(2, Person.objects.bulk_update, people)

    def test_already_evaluated_queryset(self):
        """
        Queries:
            - update objects

        (objects are already retrieved, because of the previous loop)
        """
        people = Person.objects.all()
        for person in people:
            person.age += 2
            person.name = Func(F('name'), function='UPPER')
            person.text = 'doc'
            person.height -= Decimal('0.5')
        self.assertNumQueries(1, Person.objects.bulk_update, people)

    def test_explicit_fields(self):
        """
        Queries:
            - retrieve objects
            - update objects
        """
        people = Person.objects.all()
        self.assertNumQueries(
            2, Person.objects.bulk_update, people,
            update_fields=['date', 'time', 'image', 'slug', 'height'],
            exclude_fields=['date', 'url']
        )

    def test_deferred_fields(self):
        """
        Queries:
            - retrieve objects
            - update objects
        """
        people = Person.objects.all().only('date', 'url', 'age', 'image')
        self.assertNumQueries(2, Person.objects.bulk_update, people)

    def test_different_deferred_fields(self):
        """
        Queries:
            - retrieve objects
            - update objects
        """
        all_people = Person.objects
        people1 = all_people.filter(age__lt=10).defer('date', 'url', 'age')
        people2 = all_people.filter(age__gte=10).defer('url', 'name',
                                                       'big_age')
        people = people1 | people2
        self.assertNumQueries(2, Person.objects.bulk_update, people)

    def test_deferred_fields_and_excluded_fields(self):
        """
        Queries:
            - retrieve objects
            - update objects
        """
        people = Person.objects.all().only('date', 'age', 'time', 'image',
                                           'slug')
        self.assertNumQueries(2, Person.objects.bulk_update, people,
                              exclude_fields=['date', 'url'])

    def test_list_of_objects(self):
        """
        Queries:
            - update objects

        (objects are already retrieved, because of the cast to list)
        """
        people = list(Person.objects.all())
        self.assertNumQueries(1, Person.objects.bulk_update, people)

    def test_fields_to_update_are_deferred(self):
        """
        As all fields in 'update_fields' are deferred,
        a query will be done for each obj and field to retrieve its value.
        """
        people = Person.objects.all().only('pk')
        update_fields = ['date', 'time', 'image']
        expected_queries = len(update_fields) * Person.objects.count() + 2
        self.assertNumQueries(expected_queries, Person.objects.bulk_update,
                              people, update_fields=update_fields)

    def test_no_field_to_update(self):
        """
        Queries:
            - retrieve objects

        (as update_fields is empty, no update query will be done)
        """
        people = Person.objects.all()
        self.assertNumQueries(1, Person.objects.bulk_update, people,
                              update_fields=[])

    def test_no_objects(self):
        """
        Queries:
            - retrieve objects

        (as no objects are actually retrieved, no update query will be done)
        """
        people = Person.objects.filter(name='xxx')
        self.assertNumQueries(1, Person.objects.bulk_update, people,
                              update_fields=['age', 'height'])

    def test_batch_size(self):
        """
        Queries:
            - retrieve objects
            - update objects * 3
        """
        self.assertEqual(Person.objects.count(), 5)
        people = Person.objects.order_by('pk').all()
        self.assertNumQueries(4, Person.objects.bulk_update, people,
                              batch_size=2)


class GetFieldsTests(TestCase):
    """
    Unit tests for ``helper.get_fields`` field selection.
    """

    # Number of fields get_fields returns for Person with no update/exclude
    # filtering and no deferred fields (see test_get_all_fields).
    total_fields = 24

    def setUp(self):
        create_fixtures()

    def _assertEquals(self, fields, names):
        """Assert the names of ``fields`` are exactly ``names`` (any order)."""
        self.assertEqual(
            {field.name for field in fields},
            set(names),
        )

    def _assertIn(self, names, fields):
        """Assert every name in ``names`` is the name of some field."""
        field_names = [field.name for field in fields]
        for name in names:
            self.assertIn(name, field_names)

    def _assertNotIn(self, names, fields):
        """Assert no name in ``names`` is the name of any field."""
        field_names = [field.name for field in fields]
        for name in names:
            self.assertNotIn(name, field_names)

    def test_get_all_fields(self):
        meta = Person.objects.first()._meta
        update_fields = None
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self.assertEqual(len(fields), self.total_fields)

    def test_dont_get_primary_key(self):
        meta = Person.objects.first()._meta
        update_fields = None
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertIn(['id'], meta.get_fields())  # sanity check
        self._assertNotIn(['id'], fields)  # actual test

        meta = PersonUUID.objects.create(age=3)._meta
        update_fields = None
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertIn(['uuid'], meta.get_fields())  # sanity check
        self._assertNotIn(['uuid'], fields)  # actual test

    def test_dont_get_reversed_relations(self):
        meta = Person.objects.first()._meta
        update_fields = None
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertIn(['companies'], meta.get_fields())  # sanity check
        self._assertNotIn(['companies'], fields)  # actual test

    def test_dont_get_many_to_many_relations(self):
        meta = Person.objects.first()._meta
        update_fields = None
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertIn(['jobs'], meta.get_fields())  # sanity check
        self._assertNotIn(['jobs'], fields)  # actual test

    def test_update_fields(self):
        meta = Person.objects.first()._meta
        update_fields = ['age', 'email', 'text']
        exclude_fields = []

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertEquals(fields, ['age', 'email', 'text'])

    def test_update_fields_and_exclude_fields(self):
        meta = Person.objects.first()._meta
        update_fields = ['age', 'email', 'text']
        exclude_fields = ['email', 'height']

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertEquals(fields, ['age', 'text'])

    def test_empty_update_fields(self):
        meta = Person.objects.first()._meta
        update_fields = []
        exclude_fields = ['email', 'height']

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertEquals(fields, [])

    def test_exclude_a_foreignkey(self):
        meta = Person.objects.first()._meta
        update_fields = None
        exclude_fields = ['email', 'role']

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self.assertEqual(len(fields), self.total_fields - 2)
        self._assertNotIn(['email', 'role'], fields)

    def test_exclude_foreignkey_with_id_suffix(self):
        meta = Person.objects.first()._meta
        update_fields = None
        exclude_fields = ['email', 'role_id']

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self.assertEqual(len(fields), self.total_fields - 2)
        self._assertNotIn(['email', 'role'], fields)

    def test_get_a_foreignkey(self):
        meta = Person.objects.first()._meta
        update_fields = ['role', 'my_file']
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertEquals(fields, ['role', 'my_file'])

    def test_get_foreignkey_with_id_suffix(self):
        meta = Person.objects.first()._meta
        update_fields = ['role_id', 'my_file']
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertEquals(fields, ['role', 'my_file'])

    def test_obj_argument(self):
        obj = Person.objects.first()
        meta = obj._meta
        update_fields = None
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta, obj)
        self.assertEqual(len(fields), self.total_fields)

    def test_only_get_not_deferred_fields(self):
        obj = Person.objects.only('name', 'age', 'height').first()
        meta = obj._meta
        update_fields = None
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta, obj)
        self._assertEquals(fields, ['name', 'age', 'height'])

    def test_only_and_exclude_fields(self):
        obj = Person.objects.only('name', 'age', 'height').first()
        meta = obj._meta
        update_fields = None
        exclude_fields = ['age', 'date']

        fields = helper.get_fields(update_fields, exclude_fields, meta, obj)
        self._assertEquals(fields, ['name', 'height'])

    def test_only_and_exclude_fields_02(self):
        obj = Person.objects.defer('age', 'height').first()
        meta = obj._meta
        update_fields = None
        exclude_fields = ['image', 'data']

        fields = helper.get_fields(update_fields, exclude_fields, meta, obj)
        self.assertEqual(len(fields), self.total_fields - 4)
        self._assertNotIn(['age', 'height', 'image', 'data'], fields)

    def test_update_fields_over_not_deferred_field(self):
        obj = Person.objects.only('name', 'age', 'height').first()
        meta = obj._meta
        update_fields = ['date', 'time', 'age']
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta, obj)
        self._assertEquals(fields, ['date', 'time', 'age'])

    def test_update_fields_over_not_deferred_field_02(self):
        obj = Person.objects.only('name', 'age', 'height').first()
        meta = obj._meta
        update_fields = []
        exclude_fields = None

        fields = helper.get_fields(update_fields, exclude_fields, meta, obj)
        self._assertEquals(fields, [])

    def test_arguments_as_tuples(self):
        meta = Person.objects.first()._meta
        update_fields = ('age', 'email', 'text')
        exclude_fields = ('email', 'height')

        fields = helper.get_fields(update_fields, exclude_fields, meta)
        self._assertEquals(fields, ['age', 'text'])

    def test_validate_fields(self):
        meta = Person.objects.first()._meta

        update_fields = ['age', 'wrong_name', 'text']
        exclude_fields = ('email', 'height')
        self.assertRaises(TypeError, helper.get_fields,
                          update_fields, exclude_fields, meta)

        update_fields = ('age', 'email', 'text')
        exclude_fields = ('email', 'bad_name')
        self.assertRaises(TypeError, helper.get_fields,
                          update_fields, exclude_fields, meta)

        update_fields = ('companies', )
        exclude_fields = None
        self.assertRaises(TypeError, helper.get_fields,
                          update_fields, exclude_fields, meta)

        update_fields = None
        exclude_fields = ['jobs']
        self.assertRaises(TypeError, helper.get_fields,
                          update_fields, exclude_fields, meta)