import io
import os
import operator
import tempfile

from django.utils.six.moves import reduce
from django.core.files.storage import default_storage
from django.core.files.uploadedfile import SimpleUploadedFile
from django.db.models.fields.files import FieldFile
from django.test import SimpleTestCase, TestCase
from greenwich import raster
from PIL import Image

from spillway.models import upload_to

from .models import RasterStore

def create_image(multiband=False):
    """Write a small test raster into default storage and return its file.

    The raster is 5x5, or 5x5x3 when ``multiband`` is true, and its pixel
    values are the consecutive integers ``0 .. n-1`` where ``n`` is the
    product of the shape dimensions.

    Returns the storage file object, opened in ``'w+b'`` mode and rewound to
    the start. The caller is responsible for closing it.

    If building or saving the raster fails, the storage file handle is closed
    before the exception propagates so that it is not leaked.
    """
    # mktemp() is used only to obtain a unique base name; the file itself is
    # created through default_storage under upload_to.path.
    tmpname = os.path.join(
        upload_to.path,
        os.path.basename(tempfile.mktemp(prefix='tmin_', suffix='.tif')))
    shape = (5, 5)
    if multiband:
        shape += (3,)
    pixels = bytearray(range(reduce(operator.mul, shape)))
    fp = default_storage.open(tmpname, 'w+b')
    try:
        ras = raster.frombytes(bytes(pixels), shape)
        try:
            # NOTE(review): presumably GDAL-style geotransform coefficients
            # and an EPSG code (4326) -- confirm against greenwich.
            ras.affine = (-120, 2, 0, 38, 0, -2)
            ras.sref = 4326
            ras.save(fp)
        finally:
            # Release the raster even when save() raises.
            ras.close()
        fp.seek(0)
    except Exception:
        fp.close()
        raise
    return fp


class RasterTestBase(SimpleTestCase):
    """Base test case sharing one generated raster file per test class.

    ``setUpClass`` creates the image via ``create_image`` and stores the open
    file on ``cls.f``; ``tearDownClass`` closes it. Each test additionally
    gets ``self.data``, a dict whose ``'image'`` entry is a ``FieldFile``
    bound to the ``RasterStore.image`` field.
    """

    # Forwarded to create_image(); subclasses may override it to request the
    # multiband variant of the fixture image.
    use_multiband = False

    @classmethod
    def setUpClass(cls):
        """Generate the fixture image, then run the parent class setup."""
        cls.f = create_image(cls.use_multiband)
        super(RasterTestBase, cls).setUpClass()

    @classmethod
    def tearDownClass(cls):
        """Close the fixture image, then run the parent class teardown."""
        cls.f.close()
        super(RasterTestBase, cls).tearDownClass()

    def setUp(self):
        """Build ``self.data`` with a FieldFile pointing at the fixture."""
        # Strip the storage location prefix from the file's name.
        storage_prefix = '%s/' % default_storage.location
        relative_name = self.f.name.replace(storage_prefix, '')
        image_field = RasterStore._meta.get_field('image')
        self.data = {'image': FieldFile(None, image_field, relative_name)}

    def _image(self, imgdata):
        """Open in-memory image bytes with PIL and return the image."""
        buf = io.BytesIO(imgdata)
        return Image.open(buf)


class RasterStoreTestBase(RasterTestBase, TestCase):
    """Fixture that also creates a ``RasterStore`` row for each test.

    After ``setUp``, ``self.object`` is the created ``RasterStore`` and
    ``self.qs`` is the ``RasterStore.objects.all()`` queryset.
    """

    def setUp(self):
        """Create a RasterStore from the fixture image name."""
        super(RasterStoreTestBase, self).setUp()
        image_name = self.data['image'].name
        self.object = RasterStore.objects.create(image=image_name)
        self.qs = RasterStore.objects.all()


class RasterStoreTestCase(RasterStoreTestBase):
    """Tests for ``RasterStore`` against the generated fixture raster."""

    def test_array(self):
        """The array value at the geometry centroid is 12."""
        # The centroid is transformed to SRID 3310 before the lookup.
        point = self.object.geom.centroid.transform(3310, clone=True)
        self.assertEqual(self.object.array(point).squeeze(), 12)

    def test_save_uploadfile(self):
        """Saving a model built from an uploaded file stores that file."""
        upload = SimpleUploadedFile('up.tif', self.object.image.read())
        rstore = RasterStore(image=upload)
        rstore.save()
        # Storage.exists() takes a file name, not a FieldFile; passing the
        # FieldFile only works where Django coerces path parts to text.
        self.assertTrue(default_storage.exists(rstore.image.name))
        self.assertEqual(rstore.image.size, self.f.size)

    def test_linear(self):
        """linear() yields five evenly spaced values over the range."""
        self.assertEqual(list(self.object.linear()),
                         [0., 6., 12., 18., 24.])
        # An explicit (2, 20) argument yields values running from 2 to 20.
        self.assertEqual(list(self.object.linear((2, 20))),
                         [2., 6.5, 11., 15.5, 20.])

    def test_quantiles(self):
        """quantiles() yields the expected break values for the fixture."""
        self.assertEqual(list(self.object.quantiles()),
                         [0., 6., 12., 18., 24.])