"""
To write tf_record into file. Here we use it for tensorboard's event writting.
The code was borrowed from https://github.com/TeamHG-Memex/tensorboard_logger
"""

import os
import copy
import io
import os.path
import re
import struct
try:
    import boto3
    S3_ENABLED = True
except ImportError:
    S3_ENABLED = False
try:
    from google.cloud import storage
    GCS_ENABLED = True
except ImportError:
    GCS_ENABLED = False

from .crc32c import crc32c


_VALID_OP_NAME_START = re.compile('^[A-Za-z0-9.]')
_VALID_OP_NAME_PART = re.compile(r'[A-Za-z0-9_.\-/]+')

# Registry of writer factories, keyed by path prefix.
#
# Currently supports "s3://" URLs for S3 (via boto3) and "gs://" URLs
# for Google Cloud Storage, and falls back to the local filesystem.
REGISTERED_FACTORIES = {}


def register_writer_factory(prefix, factory):
    if ':' in prefix:
        raise ValueError('prefix cannot contain a :')
    REGISTERED_FACTORIES[prefix] = factory


def directory_check(path):
    '''Initialize the directory for log files.'''
    try:
        prefix = path.split(':')[0]
        factory = REGISTERED_FACTORIES[prefix]
        return factory.directory_check(path)
    except KeyError:
        # No factory is registered for this prefix, so fall back to the
        # local filesystem and create the directory if it does not exist.
        if not os.path.exists(path):
            os.makedirs(path)


def open_file(path):
    '''Open a writer for outputting event files.'''
    try:
        prefix = path.split(':')[0]
        factory = REGISTERED_FACTORIES[prefix]
        return factory.open(path)
    except KeyError:
        # No factory is registered for this prefix, so open a plain local
        # file in binary mode.
        return open(path, 'wb')
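

# Usage sketch (illustrative; the path is hypothetical). A path without a
# registered prefix falls through to the local filesystem:
#
#     directory_check('runs/exp1')
#     f = open_file('runs/exp1/events.out')  # plain local file object
#     f.write(b'record bytes')
#     f.close()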


class S3RecordWriter(object):
    """Writes tensorboard protocol buffer files to S3."""

    def __init__(self, path):
        if not S3_ENABLED:
            raise ImportError("boto3 must be installed for S3 support.")
        self.path = path
        self.buffer = io.BytesIO()

    def __del__(self):
        # Guard against partially constructed objects (e.g. when boto3 is
        # not installed and __init__ raised before creating the buffer).
        if hasattr(self, 'buffer'):
            self.close()

    def bucket_and_path(self):
        """Split the S3 URL into a (bucket, key) pair; for example, a path
        "s3://my-bucket/run1/events" yields ("my-bucket", "run1/events")."""
        path = self.path
        if path.startswith("s3://"):
            path = path[len("s3://"):]
        bp = path.split("/")
        bucket = bp[0]
        path = path[1 + len(bucket):]
        return bucket, path

    def write(self, val):
        self.buffer.write(val)

    def flush(self):
        # Each flush re-uploads the full buffer: S3 objects are immutable,
        # so the file is rewritten rather than appended to. S3_ENDPOINT
        # allows pointing at an S3-compatible service.
        s3 = boto3.client('s3', endpoint_url=os.environ.get('S3_ENDPOINT'))
        bucket, path = self.bucket_and_path()
        upload_buffer = copy.copy(self.buffer)
        upload_buffer.seek(0)
        s3.upload_fileobj(upload_buffer, bucket, path)

    def close(self):
        self.flush()


class S3RecordWriterFactory(object):
    """Factory for event protocol buffer files to S3."""

    def open(self, path):
        return S3RecordWriter(path)

    def directory_check(self, path):
        # S3 doesn't need directories created before files are added
        # so we can just skip this check
        pass


register_writer_factory("s3", S3RecordWriterFactory())


class GCSRecordWriter(object):
    """Writes tensorboard protocol buffer files to Google Cloud Storage."""

    def __init__(self, path):
        if not GCS_ENABLED:
            raise ImportError("`google-cloud-storage` must be installed in order to use "
                              "the 'gs://' protocol")

        self.path = path
        self.buffer = io.BytesIO()

        client = storage.Client()
        bucket_name, filepath = self.bucket_and_path()
        bucket = storage.Bucket(client, bucket_name)
        self.blob = storage.Blob(filepath, bucket)

    def __del__(self):
        # Guard against partially constructed objects (__init__ may raise
        # before the blob is created).
        if hasattr(self, 'blob'):
            self.close()

    def bucket_and_path(self):
        """Split the GCS URL into a (bucket, blob) pair; for example, a path
        "gs://my-bucket/run1/events" yields ("my-bucket", "run1/events")."""
        path = self.path
        if path.startswith("gs://"):
            path = path[len("gs://"):]
        bp = path.split("/")
        bucket = bp[0]
        path = path[1 + len(bucket):]
        return bucket, path

    def write(self, val):
        self.buffer.write(val)

    def flush(self):
        # getvalue() returns the entire buffer contents regardless of the
        # current position, so the whole file is re-uploaded on each flush.
        self.blob.upload_from_string(self.buffer.getvalue())

    def close(self):
        self.flush()


class GCSRecordWriterFactory(object):
    """Factory for event protocol buffer files to Google Cloud Storage."""

    def open(self, path):
        return GCSRecordWriter(path)

    def directory_check(self, path):
        # Google Cloud Storage doesn't need directories created before files
        # are added so we can just skip this check
        pass


register_writer_factory("gs", GCSRecordWriterFactory())


class RecordWriter(object):
    """Writes data to a file in TFRecord format (length-prefixed,
    CRC-checksummed records)."""

    def __init__(self, path):
        self._name_to_tf_name = {}
        self._tf_names = set()
        self.path = path
        self._writer = open_file(path)

    def write(self, data):
        w = self._writer.write
        # TFRecord framing: a little-endian uint64 length, a uint32 masked
        # CRC of that length header, the payload, and a uint32 masked CRC
        # of the payload. The '<' forces little-endian byte order as the
        # format requires, regardless of platform.
        header = struct.pack('<Q', len(data))
        w(header)
        w(struct.pack('<I', masked_crc32c(header)))
        w(data)
        w(struct.pack('<I', masked_crc32c(data)))

    def flush(self):
        self._writer.flush()

    def close(self):
        self._writer.close()
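

# A minimal usage sketch (path and payload are hypothetical; TensorBoard
# event files contain serialized Event protobufs as record payloads):
#
#     writer = RecordWriter('runs/exp1/events.out')
#     writer.write(b'serialized event bytes')
#     writer.flush()
#     writer.close()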


def masked_crc32c(data):
    # TensorFlow's CRC masking: rotate the CRC right by 15 bits, then add a
    # constant, so that a checksum stored next to the data it covers does
    # not itself look like plausible data.
    x = u32(crc32c(data))
    return u32(((x >> 15) | u32(x << 17)) + 0xa282ead8)


def u32(x):
    # Truncate to an unsigned 32-bit value.
    return x & 0xffffffff


def make_valid_tf_name(name):
    '''Coerce an arbitrary name into a valid TensorFlow op name.'''
    if not _VALID_OP_NAME_START.match(name):
        # Must make it valid somehow, but don't want to remove stuff
        name = '.' + name
    return '_'.join(_VALID_OP_NAME_PART.findall(name))
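

# Examples of the transformation (illustrative inputs):
#
#     make_valid_tf_name('loss/train')  # -> 'loss/train' (already valid)
#     make_valid_tf_name('my metric!')  # -> 'my_metric'
#     make_valid_tf_name('@weights')    # -> '._weights' (invalid start char)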