#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
#

from io import BytesIO
import random
import pytest
import decimal
import datetime
import pytz
import os
import platform
from snowflake.connector.arrow_context import ArrowConverterContext
from snowflake.connector.options import installed_pandas
from snowflake.connector.converter import (
    _generate_tzinfo_from_tzoffset)

try:
    import tzlocal
except ImportError:
    tzlocal = None

try:
    from pyarrow import RecordBatchStreamReader  # NOQA
    from pyarrow import RecordBatchStreamWriter  # NOQA
    from pyarrow import RecordBatch  # NOQA
    import pyarrow
except ImportError:
    pass

try:
    from snowflake.connector.arrow_iterator import PyArrowIterator  # NOQA
    from snowflake.connector.arrow_iterator import ROW_UNIT  # NOQA
    no_arrow_iterator_ext = False
except ImportError:
    no_arrow_iterator_ext = True


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_string_chunk():
    """Round-trip two random TEXT columns through iterate_over_test_chunk."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = [
        {"logicalType": "TEXT"},
        {"logicalType": "TEXT"}
    ]

    def str_generator():
        return str(random.randint(-100, 100))

    iterate_over_test_chunk([pyarrow.string(), pyarrow.string()],
                            column_meta, str_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_int64_chunk():
    """Round-trip two FIXED(38, 0) columns stored as Arrow int64."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = [
        {"logicalType": "FIXED", "precision": "38", "scale": "0"},
        {"logicalType": "FIXED", "precision": "38", "scale": "0"}
    ]

    def int64_generator():
        # Full signed 64-bit range.
        return random.randint(-9223372036854775808, 9223372036854775807)

    iterate_over_test_chunk([pyarrow.int64(), pyarrow.int64()],
                            column_meta, int64_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_int32_chunk():
    """Round-trip two FIXED(10, 0) columns stored as Arrow int32."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = [
        {"logicalType": "FIXED", "precision": "10", "scale": "0"},
        {"logicalType": "FIXED", "precision": "10", "scale": "0"}
    ]

    def int32_generator():
        # Full signed 32-bit range (the upper bound used to be mistyped as 2147483637).
        return random.randint(-2147483648, 2147483647)

    iterate_over_test_chunk([pyarrow.int32(), pyarrow.int32()],
                            column_meta, int32_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_int16_chunk():
    """Round-trip two FIXED(5, 0) columns stored as Arrow int16."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = [
        {"logicalType": "FIXED", "precision": "5", "scale": "0"},
        {"logicalType": "FIXED", "precision": "5", "scale": "0"}
    ]

    def int16_generator():
        # Full signed 16-bit range.
        return random.randint(-32768, 32767)

    iterate_over_test_chunk([pyarrow.int16(), pyarrow.int16()],
                            column_meta, int16_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_int8_chunk():
    """Round-trip two FIXED(3, 0) columns stored as Arrow int8."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = [
        {"logicalType": "FIXED", "precision": "3", "scale": "0"},
        {"logicalType": "FIXED", "precision": "3", "scale": "0"}
    ]

    def int8_generator():
        # Full signed 8-bit range.
        return random.randint(-128, 127)

    iterate_over_test_chunk([pyarrow.int8(), pyarrow.int8()],
                            column_meta, int8_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_bool_chunk():
    """Round-trip two random BOOLEAN columns through iterate_over_test_chunk."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = {"logicalType": "BOOLEAN"}

    def bool_generator():
        return bool(random.getrandbits(1))

    iterate_over_test_chunk([pyarrow.bool_(), pyarrow.bool_()],
                            [column_meta, column_meta],
                            bool_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_float_chunk():
    """Round-trip a REAL and a FLOAT column, both stored as Arrow float64."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = [
        {"logicalType": "REAL"},
        {"logicalType": "FLOAT"}
    ]

    def float_generator():
        return random.uniform(-100.0, 100.0)

    iterate_over_test_chunk([pyarrow.float64(), pyarrow.float64()],
                            column_meta, float_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_decimal_chunk():
    """Round-trip FIXED(precision, scale) columns with a random precision and scale.

    Precision <= 19 is stored as an Arrow integer holding the unscaled value
    (the expected result is that integer scaled by 10**-scale); larger
    precisions are stored as decimal128 and compared as-is.
    """
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    precision = random.randint(1, 38)
    scale = random.randint(0, precision)
    if precision <= 2:
        datatype = pyarrow.int8()
    elif precision <= 4:
        datatype = pyarrow.int16()
    elif precision <= 9:
        datatype = pyarrow.int32()
    elif precision <= 19:
        datatype = pyarrow.int64()
    else:
        datatype = pyarrow.decimal128(precision, scale)

    def truncated_int_generator(lower, upper, digits):
        # Draw from the storage type's full range, then keep only the leading
        # `digits` characters (plus one for a minus sign) so the value has at
        # most `digits` decimal digits.
        value = str(random.randint(lower, upper))
        return int(value[:digits + 1 if value.startswith('-') else digits])

    def decimal128_generator(digits, scale):
        # `digits` random decimal digits with a '.' before the last `scale` of them.
        data = [str(random.randint(0, 9)) for _ in range(digits)]
        if scale:
            data.insert(-scale, '.')
        return decimal.Decimal("".join(data))

    def decimal_generator(_precision, _scale):
        if _precision <= 2:
            return truncated_int_generator(-128, 127, _precision)
        elif _precision <= 4:
            return truncated_int_generator(-32768, 32767, _precision)
        elif _precision <= 9:
            return truncated_int_generator(-2147483648, 2147483647, _precision)
        elif _precision <= 19:
            return truncated_int_generator(-9223372036854775808, 9223372036854775807, _precision)
        else:
            return decimal128_generator(_precision, _scale)

    def expected_data_transform_decimal(_precision, _scale):
        def expected_data_transform_decimal_impl(data, precision=_precision, scale=_scale):
            if precision <= 19:
                # Integer storage holds the unscaled value.
                return decimal.Decimal(data).scaleb(-scale)
            else:
                return data

        return expected_data_transform_decimal_impl

    column_meta = {"logicalType": "FIXED", "precision": str(precision), "scale": str(scale)}
    iterate_over_test_chunk([datatype, datatype], [column_meta, column_meta],
                            lambda: decimal_generator(precision, scale),
                            expected_data_transform_decimal(precision, scale))


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_date_chunk():
    """Round-trip two random DATE columns stored as Arrow date32."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = {
        "byteLength": "4",
        "logicalType": "DATE",
        "precision": "38",
        "scale": "0",
        "charLength": "0"
    }

    def date_generator():
        # Proleptic Gregorian ordinals 1..1000000.
        return datetime.date.fromordinal(random.randint(1, 1000000))

    iterate_over_test_chunk([pyarrow.date32(), pyarrow.date32()],
                            [column_meta, column_meta],
                            date_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_binary_chunk():
    """Round-trip two BINARY columns of random 1000-byte values."""
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta = {
        "byteLength": "100",
        "logicalType": "BINARY",
        "precision": "0",
        "scale": "0",
        "charLength": "0"
    }

    def byte_array_generator():
        return bytearray(os.urandom(1000))

    iterate_over_test_chunk([pyarrow.binary(), pyarrow.binary()],
                            [column_meta, column_meta],
                            byte_array_generator)


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_time_chunk():
    """Round-trip TIME columns stored as int64 (scale 9) and int32 (scale 4).

    Each stored integer is the time of day in 10**-scale second units.
    """
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    column_meta_int64 = [
        {"logicalType": "TIME", "scale": "9"},
        {"logicalType": "TIME", "scale": "9"}
    ]

    column_meta_int32 = [
        {"logicalType": "TIME", "scale": "4"},
        {"logicalType": "TIME", "scale": "4"}
    ]

    def time_generator_int64():
        # [0, 24h) in nanoseconds: 86400 * 10**9 - 1 is the largest value.
        return random.randint(0, 86399999999999)

    def time_generator_int32():
        # [0, 24h) in 10**-4 second units: 86400 * 10**4 - 1 is the largest value.
        return random.randint(0, 863999999)

    def expected_time_transform(scale):
        def transform(data):
            total_seconds, frac = divmod(data, 10 ** scale)
            minutes, second = divmod(total_seconds, 60)
            hour, minute = divmod(minutes, 60)
            # datetime.time keeps only microseconds; finer digits are truncated.
            microsecond = frac * 10 ** 6 // 10 ** scale
            return datetime.time(hour, minute, second, microsecond)

        return transform

    iterate_over_test_chunk([pyarrow.int64(), pyarrow.int64()],
                            column_meta_int64, time_generator_int64, expected_time_transform(9))

    iterate_over_test_chunk([pyarrow.int32(), pyarrow.int32()],
                            column_meta_int32, time_generator_int32, expected_time_transform(4))


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_timestamp_ntz_chunk():
    """Round-trip TIMESTAMP_NTZ values with a random scale in [0, 9].

    Scales above 7 are stored as an {epoch, fraction} struct with the fraction
    in nanoseconds; smaller scales as one int64 in 10**-scale second units.
    """
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    scale = random.randint(0, 9)
    column_meta = [
        {"logicalType": "TIMESTAMP_NTZ", "scale": str(scale)},
        {"logicalType": "TIMESTAMP_NTZ", "scale": str(scale)}
    ]
    data_type = pyarrow.struct([pyarrow.field('epoch', pyarrow.int64()),
                                pyarrow.field('fraction', pyarrow.int32())]) if scale > 7 else pyarrow.int64()

    def timestamp_ntz_generator(scale):
        epoch = random.randint(-621355968, 2534023007)
        if scale > 7:
            frac = random.randint(0, 10**scale - 1) * (10**(9 - scale))
            return {'epoch': epoch, 'fraction': frac}
        frac = random.randint(0, 10**scale - 1)
        if not scale:
            return epoch
        # Digit concatenation: epoch's sign covers the whole value,
        # e.g. epoch=-12, frac=5, scale=2 -> -1205, i.e. -12.05 seconds.
        return int(str(epoch) + str(frac).zfill(scale))

    def to_seconds_string(value, scale):
        """Format value * 10**-scale as a decimal string truncated to microseconds."""
        if not scale:
            return str(value)
        sign = '-' if value < 0 else ''
        whole, frac = divmod(abs(value), 10**scale)
        return '{}{}.{}'.format(sign, whole, str(frac).zfill(scale)[:6])

    def expected_data_transform_ntz(_scale):
        def expected_data_transform_ntz_impl(data, scale=_scale):
            if scale > 7:
                # The struct encodes epoch + fraction * 1e-9 with a non-negative
                # fraction. Folding it arithmetically into one signed count of
                # 10**-scale units stays correct for epoch == -1 and fraction == 0,
                # where the old digit-concatenation approach broke.
                data = data['epoch'] * 10**scale + data['fraction'] // 10**(9 - scale)
            seconds = to_seconds_string(data, scale)

            if platform.system() == 'Windows':
                return datetime.datetime.utcfromtimestamp(0) + datetime.timedelta(seconds=float(seconds))
            else:
                return datetime.datetime.utcfromtimestamp(float(seconds))

        return expected_data_transform_ntz_impl

    iterate_over_test_chunk([data_type, data_type],
                            column_meta, lambda: timestamp_ntz_generator(scale), expected_data_transform_ntz(scale))


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_timestamp_ltz_chunk():
    """Round-trip TIMESTAMP_LTZ values with a random scale in [0, 9].

    Scales above 7 are stored as an {epoch, fraction} struct with the fraction
    in nanoseconds; smaller scales as one int64 in 10**-scale second units.
    Expected values are expressed in the zone returned by get_timezone().
    """
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    scale = random.randint(0, 9)
    column_meta = [
        {"logicalType": "TIMESTAMP_LTZ", "scale": str(scale)},
        {"logicalType": "TIMESTAMP_LTZ", "scale": str(scale)}
    ]
    data_type = pyarrow.struct([pyarrow.field('epoch', pyarrow.int64()),
                                pyarrow.field('fraction', pyarrow.int32())]) if scale > 7 else pyarrow.int64()

    def timestamp_ltz_generator(scale):
        epoch = random.randint(-621355968, 2534023007)
        if scale > 7:
            frac = random.randint(0, 10**scale - 1) * (10**(9 - scale))
            return {'epoch': epoch, 'fraction': frac}
        frac = random.randint(0, 10**scale - 1)
        if not scale:
            return epoch
        # Digit concatenation: epoch's sign covers the whole value,
        # e.g. epoch=-12, frac=5, scale=2 -> -1205, i.e. -12.05 seconds.
        return int(str(epoch) + str(frac).zfill(scale))

    def to_seconds_string(value, scale):
        """Format value * 10**-scale as a decimal string truncated to microseconds."""
        if not scale:
            return str(value)
        sign = '-' if value < 0 else ''
        whole, frac = divmod(abs(value), 10**scale)
        return '{}{}.{}'.format(sign, whole, str(frac).zfill(scale)[:6])

    def expected_data_transform_ltz(_scale):
        def expected_data_transform_ltz_impl(data, scale=_scale):
            tzinfo = get_timezone()   # can put a string parameter here in the future
            if scale > 7:
                # The struct encodes epoch + fraction * 1e-9 with a non-negative
                # fraction. Folding it arithmetically into one signed count of
                # 10**-scale units stays correct for epoch == -1 and fraction == 0,
                # where the old digit-concatenation approach broke.
                data = data['epoch'] * 10**scale + data['fraction'] // 10**(9 - scale)
            seconds = to_seconds_string(data, scale)

            if platform.system() == 'Windows':
                t0 = datetime.datetime.utcfromtimestamp(0) + datetime.timedelta(seconds=float(seconds))
                return pytz.utc.localize(t0, is_dst=False).astimezone(tzinfo)
            else:
                return datetime.datetime.fromtimestamp(float(seconds), tz=tzinfo)

        return expected_data_transform_ltz_impl

    iterate_over_test_chunk([data_type, data_type],
                            column_meta, lambda: timestamp_ltz_generator(scale), expected_data_transform_ltz(scale))


@pytest.mark.skipif(
    not installed_pandas or no_arrow_iterator_ext,
    reason="arrow_iterator extension is not built, or pandas option is not installed.")
def test_iterate_over_timestamp_tz_chunk():
    """Round-trip TIMESTAMP_TZ values with a random scale in [0, 9].

    Scales above 3 are stored as an {epoch, timezone, fraction} struct with the
    fraction in nanoseconds; smaller scales as {epoch, timezone}, where epoch
    is in 10**-scale second units.
    """
    # random.seed() rejects datetime objects on Python 3.11+; seed from the timestamp.
    random.seed(datetime.datetime.now().timestamp())
    scale = random.randint(0, 9)
    column_meta = [
        {"byteLength": "16" if scale > 3 else "8", "logicalType": "TIMESTAMP_TZ", "scale": str(scale)},
        {"byteLength": "16" if scale > 3 else "8", "logicalType": "TIMESTAMP_TZ", "scale": str(scale)}
    ]

    type1 = pyarrow.struct([pyarrow.field('epoch', pyarrow.int64()),
                            pyarrow.field('timezone', pyarrow.int32()),
                            pyarrow.field('fraction', pyarrow.int32())])
    type2 = pyarrow.struct([pyarrow.field('epoch', pyarrow.int64()),
                            pyarrow.field('timezone', pyarrow.int32())])
    data_type = type1 if scale > 3 else type2

    def timestamp_tz_generator(scale):
        epoch = random.randint(-621355968, 2534023007)
        if scale > 3:
            frac = random.randint(0, 10**scale - 1) * (10**(9 - scale))
        else:
            frac = random.randint(0, 10**scale - 1)
        timezone = random.randint(1, 2879)
        if scale > 3:
            return {'epoch': epoch, 'timezone': timezone, 'fraction': frac}
        if not scale:
            return {'epoch': epoch, 'timezone': timezone}
        # Digit concatenation: epoch's sign covers the whole value,
        # e.g. epoch=-12, frac=5, scale=2 -> -1205, i.e. -12.05 seconds.
        return {'epoch': int(str(epoch) + str(frac).zfill(scale)), 'timezone': timezone}

    def to_seconds_string(value, scale):
        """Format value * 10**-scale as a decimal string truncated to microseconds."""
        if not scale:
            return str(value)
        sign = '-' if value < 0 else ''
        whole, frac = divmod(abs(value), 10**scale)
        return '{}{}.{}'.format(sign, whole, str(frac).zfill(scale)[:6])

    def expected_data_transform_tz(_scale):
        def expected_data_transform_tz_impl(data, scale=_scale):
            # The stored timezone is biased by +1440; remove the bias before building tzinfo.
            tzinfo = _generate_tzinfo_from_tzoffset(data['timezone'] - 1440)
            epoch = data['epoch']
            if scale > 3:
                # The struct encodes epoch + fraction * 1e-9 with a non-negative
                # fraction. Folding it arithmetically into one signed count of
                # 10**-scale units stays correct for epoch == -1 and fraction == 0,
                # where the old digit-concatenation approach broke.
                epoch = epoch * 10**scale + data['fraction'] // 10**(9 - scale)
            seconds = to_seconds_string(epoch, scale)

            if platform.system() == 'Windows':
                t = datetime.datetime.utcfromtimestamp(0) + datetime.timedelta(seconds=float(seconds))
                if pytz.utc != tzinfo:
                    t += tzinfo.utcoffset(t)
                return t.replace(tzinfo=tzinfo)
            else:
                return datetime.datetime.fromtimestamp(float(seconds), tz=tzinfo)

        return expected_data_transform_tz_impl

    iterate_over_test_chunk([data_type, data_type],
                            column_meta, lambda: timestamp_tz_generator(scale), expected_data_transform_tz(scale))


def iterate_over_test_chunk(pyarrow_type, column_meta, source_data_generator, expected_data_transformer=None,
                            *, batch_row_count=10, batch_count=9):
    """Write random record batches to an Arrow stream and check PyArrowIterator reads them back row by row.

    :param pyarrow_type: list of pyarrow types, one per column.
    :param column_meta: list of field metadata dicts, one per column; must match pyarrow_type in length.
    :param source_data_generator: zero-argument callable returning one non-null cell value.
    :param expected_data_transformer: optional callable mapping a generated (non-null) value to the
        value the iterator should yield; when None the generated value itself is expected.
    :param batch_row_count: rows per record batch (default 10).
    :param batch_count: number of record batches written (default 9).
    """
    assert len(pyarrow_type) == len(column_meta)

    column_size = len(pyarrow_type)
    column_names = ["column_{}".format(i) for i in range(column_size)]
    schema = pyarrow.schema([
        pyarrow.field(name, col_type, True, meta)
        for name, col_type, meta in zip(column_names, pyarrow_type, column_meta)
    ])

    stream = BytesIO()
    writer = RecordBatchStreamWriter(stream, schema)
    expected_data = []  # indexed as expected_data[batch][column][row]
    for _ in range(batch_count):
        py_arrays = []
        expected_columns = []
        for col_type in pyarrow_type:
            column_data = _random_column(source_data_generator, batch_row_count)
            # Build the Arrow array from the raw generated values before transforming.
            py_arrays.append(pyarrow.array(column_data, type=col_type))
            if expected_data_transformer:
                column_data = [None if value is None else expected_data_transformer(value)
                               for value in column_data]
            expected_columns.append(column_data)
        expected_data.append(expected_columns)
        writer.write_batch(RecordBatch.from_arrays(py_arrays, column_names))

    writer.close()

    # Rewind the stream so the iterator reads it from the beginning.
    stream.seek(0)
    context = ArrowConverterContext()
    it = PyArrowIterator(None, stream, context, False, False)
    it.init(ROW_UNIT)

    count = 0
    while True:
        try:
            row = next(it)
        except StopIteration:
            break
        batch_index, row_index = divmod(count, batch_row_count)
        for col in range(column_size):
            assert row[col] == expected_data[batch_index][col][row_index]
        count += 1

    assert count == batch_count * batch_row_count


def _random_column(source_data_generator, row_count):
    """Return row_count values, each None with probability 1/2, re-drawing until at least one is non-None."""
    while True:
        column = [None if random.getrandbits(1) else source_data_generator() for _ in range(row_count)]
        # An empty column can never contain a non-None value; return it rather than loop forever.
        if not row_count or any(value is not None for value in column):
            return column


def get_timezone(timezone=None):
    """Return a tzinfo for the given timezone name, defaulting to UTC.

    :param timezone: name passed to ``pytz.timezone``; None or '' means 'UTC'.
    :return: the pytz timezone for that name. If pytz does not recognise the
        name, a warning is logged and the local zone is returned (when tzlocal
        is installed), otherwise UTC.
    """
    tz = timezone if timezone else 'UTC'
    try:
        return pytz.timezone(tz)
    except pytz.exceptions.UnknownTimeZoneError:
        # This module defines no module-level `logger`; referencing one raised NameError.
        import logging
        logging.getLogger(__name__).warning('converting to tzinfo failed')
        if tzlocal is not None:
            return tzlocal.get_localzone()
        try:
            # `datetime` is the module here, so the attribute is datetime.timezone.
            return datetime.timezone.utc
        except AttributeError:  # interpreters without datetime.timezone
            return pytz.timezone('UTC')