from __future__ import absolute_import from datetime import timedelta, date from itertools import chain from django.conf import settings from django.db.models import CharField, Func, F, Avg, DecimalField, ExpressionWrapper from django.test import TestCase from django.utils.encoding import force_text from django_pivot.histogram import histogram from django_pivot.pivot import pivot from .models import ShirtSales, Store, Region genders = ['B', 'G'] styles = ['Tee', 'Golf', 'Fancy'] dates = ['2004-12-24', '2005-01-31', '2005-02-01', '2005-02-02', '2005-03-01', '2005-03-02', '2005-04-03', '2005-05-06'] store_names = [ 'ABC Shirts', 'Shirt Emporium', 'Just Shirts', 'Shirts R Us', 'Shirts N More' ] class DateFormat(Func): function = 'DATE_FORMAT' template = '%(function)s(%(expressions)s, "%(format)s")' def __init__(self, *expressions, **extra): strf = extra.pop('format', None) extra['format'] = strf.replace("%", "%%") extra['output_field'] = CharField() super(DateFormat, self).__init__(*expressions, **extra) class StrFtime(Func): function = 'strftime' template = '%(function)s("%(format)s", %(expressions)s)' def __init__(self, *expressions, **extra): strf = extra.pop('format', None) extra['format'] = strf.replace("%", "%%") extra['output_field'] = CharField() super(StrFtime, self).__init__(*expressions, **extra) class Tests(TestCase): @classmethod def setUpClass(cls): super(Tests, cls).setUpClass() # Generate a bunch of data to pivot Region(name='North').save() Region(name='South').save() Region(name='East').save() Region(name='West').save() regions = list(Region.objects.all()) Store(name='ABC Shirts', region=regions[0]).save() Store(name='Shirt Emporium', region=regions[1]).save() Store(name='Just Shirts', region=regions[2]).save() Store(name='Shirts R Us', region=regions[3]).save() Store(name='Shirts N More', region=regions[0]).save() units = [12, 9, 10, 15, 13, 9, 15, 3, 7] prices = [11.04, 13.00, 11.96, 11.27, 12.12, 13.74, 11.44, 12.63, 12.06, 13.42, 11.48] shirt_sales = [ShirtSales(store=store, gender=g, style=s, shipped=d) for store in Store.objects.all() for g in genders for s in styles for d in dates] for indx, shirt_sale in enumerate(shirt_sales): shirt_sale.units = units[indx % len(units)] shirt_sale.price = prices[indx % len(prices)] shirt_sales.append(ShirtSales(store=Store.objects.first(), gender=genders[0], style=styles[0], shipped='2005-07-05', units=13, price=73 )) ShirtSales.objects.bulk_create(shirt_sales) def test_pivot(self): shirt_sales = ShirtSales.objects.all() pt = pivot(ShirtSales.objects.all(), 'style', 'gender', 'units') for row in pt: style = row['style'] for gender in genders: gender_display = 'Boy' if gender == 'B' else 'Girl' self.assertEqual(row[gender_display], sum(ss.units for ss in shirt_sales if ss.style == style and ss.gender == gender)) def test_pivot_on_choice_field_row(self): shirt_sales = ShirtSales.objects.all() pt = pivot(ShirtSales.objects.all(), 'gender', 'style', 'units') for row in pt: gender = row['gender'] for style in styles: self.assertEqual(row[style], sum(ss.units for ss in shirt_sales if force_text(ss.gender) == force_text(gender) and ss.style == style)) def test_pivot_on_date(self): shirt_sales = ShirtSales.objects.all() pt = pivot(ShirtSales, 'style', 'shipped', 'units', default=0) for row in pt: style = row['style'] for dt in dates: self.assertEqual(row[dt], sum(ss.units for ss in shirt_sales if ss.style == style and force_text(ss.shipped) == dt)) pt = pivot(ShirtSales.objects, 'shipped', 'style', 'units', default=0) for row in pt: shipped = row['shipped'] for style in styles: self.assertEqual(row[style], sum(ss.units for ss in shirt_sales if force_text(ss.shipped) == force_text(shipped) and ss.style == style)) def test_pivot_on_foreignkey(self): shirt_sales = ShirtSales.objects.all() pt = pivot(ShirtSales, 'shipped', 'store__region__name', 'units', default=0) for row in pt: shipped = row['shipped'] for name in ['North', 'South', 'East', 'West']: self.assertEqual(row[name], sum(ss.units for ss in shirt_sales if force_text(ss.shipped) == force_text(shipped) and ss.store.region.name == name)) pt = pivot(ShirtSales, 'shipped', 'store__name', 'units', default=0) for row in pt: shipped = row['shipped'] for name in store_names: self.assertEqual(row[name], sum(ss.units for ss in shirt_sales if force_text(ss.shipped) == force_text(shipped) and ss.store.name == name)) def test_monthly_report(self): if settings.BACKEND == 'mysql': annotations = { 'Month': DateFormat('shipped', format='%m-%Y'), 'date_sort': DateFormat('shipped', format='%Y-%m') } elif settings.BACKEND == 'sqlite': annotations = { 'Month': StrFtime('shipped', format='%m-%Y'), 'date_sort': StrFtime('shipped', format='%Y-%m') } else: return shirt_sales = ShirtSales.objects.annotate(**annotations).order_by('date_sort') monthly_report = pivot(shirt_sales, 'Month', 'store__name', 'units', default=0) # Get the months and assert that the order by that we sent in is respected months = [record['Month'] for record in monthly_report] month_strings = ['12-2004', '01-2005', '02-2005', '03-2005', '04-2005', '05-2005', '07-2005'] self.assertEqual(months, month_strings) # Check that the aggregations are correct too for record in monthly_report: month, year = record['Month'].split('-') for name in store_names: self.assertEqual(record[name], sum(ss.units for ss in shirt_sales if (int(ss.shipped.year) == int(year) and int(ss.shipped.month) == int(month) and ss.store.name == name))) def test_pivot_with_default_fill(self): shirt_sales = ShirtSales.objects.filter(shipped__gt='2005-01-25', shipped__lt='2005-02-03') row_range = [date(2005, 1, 25) + timedelta(days=n) for n in range(14)] pt = pivot(shirt_sales, 'shipped', 'style', 'units', default=0, row_range=row_range) for row in pt: shipped = row['shipped'] for style in styles: self.assertEqual(row[style], sum(ss.units for ss in shirt_sales if force_text(ss.shipped) == force_text(shipped) and ss.style == style)) def test_pivot_aggregate(self): shirt_sales = ShirtSales.objects.all() data = ExpressionWrapper(F('units') * F('price'), output_field=DecimalField()) pt = pivot(ShirtSales, 'store__region__name', 'shipped', data, Avg, default=0) for row in pt: region_name = row['store__region__name'] for dt in (key for key in row.keys() if key != 'store__region__name'): spends = [ss.units * ss.price for ss in shirt_sales if force_text(ss.shipped) == force_text(dt) and ss.store.region.name == region_name] avg = sum(spends) / len(spends) if spends else 0 self.assertAlmostEqual(row[dt], float(avg), places=4) def test_pivot_display_transform(self): def display_transform(string): return 'prefix_' + string shirt_sales = ShirtSales.objects.all() pt = pivot(ShirtSales.objects.all(), 'style', 'gender', 'units', display_transform=display_transform) for row in pt: style = row['style'] for gender in genders: gender_display = display_transform('Boy' if gender == 'B' else 'Girl') self.assertEqual(row[gender_display], sum(ss.units for ss in shirt_sales if ss.style == style and ss.gender == gender)) def test_pivot_multiple_rows(self): shirt_sales = ShirtSales.objects.all() pt = pivot(ShirtSales.objects.all(), ('style', 'store'), 'gender', 'units') for row in pt: style = row['style'] store = row['store'] for gender in genders: gender_display = 'Boy' if gender == 'B' else 'Girl' self.assertEqual(row[gender_display], sum(ss.units for ss in shirt_sales if ss.style == style and ss.gender == gender and ss.store_id == store)) def test_pivot_on_choice_field_row_with_multiple_rows(self): shirt_sales = ShirtSales.objects.all() pt = pivot(ShirtSales.objects.all(), ('gender', 'store'), 'style', 'units') pt_reverse_rows = pivot(ShirtSales.objects.all(), ('store', 'gender'), 'style', 'units') for row in chain(pt, pt_reverse_rows): gender = row['gender'] store = row['store'] self.assertIn('get_gender_display', row) for style in styles: self.assertEqual(row[style], sum(ss.units for ss in shirt_sales if force_text(ss.gender) == force_text(gender) and ss.style == style and ss.store_id == store)) def test_histogram(self): hist = histogram(ShirtSales, 'units', bins=[0, 10, 15]) expected = [{'bin': '0', 'units': 0}, {'bin': '10', 'units': 0}, {'bin': '15', 'units': 0}] for s in ShirtSales.objects.all(): if s.units < 10: expected[0]['units'] += 1 elif s.units < 15: expected[1]['units'] += 1 else: expected[2]['units'] += 1 self.assertEqual(hist, expected) def test_multi_histogram(self): hist = histogram(ShirtSales, 'units', bins=[0, 10, 15], slice_on='gender') expected = [{'bin': '0', 'Boy': 0, 'Girl': 0}, {'bin': '10', 'Boy': 0, 'Girl': 0}, {'bin': '15', 'Boy': 0, 'Girl': 0}] for s in ShirtSales.objects.all(): if s.units < 10: if s.gender == 'B': expected[0]['Boy'] += 1 if s.gender == 'G': expected[0]['Girl'] += 1 elif s.units < 15: if s.gender == 'B': expected[1]['Boy'] += 1 if s.gender == 'G': expected[1]['Girl'] += 1 else: if s.gender == 'B': expected[2]['Boy'] += 1 if s.gender == 'G': expected[2]['Girl'] += 1 self.assertEqual(list(hist), expected) def test_histograms_with_zeros(self): hist = histogram(ShirtSales, 'units', bins=[0, 1, 2, 3, 10, 14, 15, 100, 150], slice_on='gender') # The first 3 buckets have all zero values. self.assertEqual(hist[0], {'bin': '0', 'Boy': 0, 'Girl': 0}) self.assertEqual(hist[1], {'bin': '1', 'Boy': 0, 'Girl': 0}) self.assertEqual(hist[2], {'bin': '2', 'Boy': 0, 'Girl': 0}) # A bucket in the middle has zeros self.assertEqual(hist[5], {'bin': '14', 'Boy': 0, 'Girl': 0}) # The last 3 buckets have zero values self.assertEqual(hist[7], {'bin': '100', 'Boy': 0, 'Girl': 0}) self.assertEqual(hist[8], {'bin': '150', 'Boy': 0, 'Girl': 0})