"""Sample caching support."""

import struct
import pickle
import mmap
import io
import numpy as np
import lz4.frame

from vergeml import VergeMLError

class Cache:
    """Abstract base class for caches."""

    def write(self, data, meta):
        """Write data and metadata to the cache."""
        raise NotImplementedError

    def read(self, index, n_samples):
        """Read n_samples at index from the cache."""
        raise NotImplementedError

class MemoryCache(Cache):
    """Cache samples in memory."""

    def __init__(self):
        # List of (data, meta) tuples in insertion order.
        self.data = []

    def __len__(self):
        return len(self.data)

    def write(self, data, meta):
        """Append a (data, meta) sample to the in-memory store."""
        self.data.append((data, meta))

    def read(self, index, n_samples):
        """Return up to n_samples (data, meta) tuples starting at index."""
        return self.data[index:index+n_samples]

class _CacheFileContent:

    def __init__(self):

        # An index of the positions of the stored data items.
        self.index = []

        # Sample metadata.
        self.meta = []

        # Info (Used to store data types)
        self.info = None

    def read(self, file, path):
        """Read the content index from file.
        pos, = struct.unpack('<Q', file.read(8))
        if pos == 0:
            raise VergeMLError("Invalid cache file: {}".format(path))
        self.index, self.meta, self.info = pickle.load(file)

    def write(self, file):
        """Write the content index to file and update the header.
        pos = file.tell()
        pickle.dump((self.index, self.meta, self.info), file)

        # update the header with the position of the content index.
        file.write(struct.pack('<Q', pos))

class FileCache(Cache):
    """Cache raw bytes in a mmapped file."""

    def __init__(self, path, mode):
        """Open a cache file.

        :param path: path of the cache file.
        :param mode: "r" to read an existing cache, "w" to create one.
        """
        assert mode in ("r", "w")

        self.path = path
        self.file = open(self.path, mode + "b")
        self.mmfile = None
        self.mode = mode
        self.cnt = _CacheFileContent()

        if mode == "r":
            # Read the last part of the file which contains the contents of the
            # cache.
            self.cnt.read(self.file, self.path)
            self.mmfile = mmap.mmap(self.file.fileno(), 0, access=mmap.ACCESS_READ)
        else:
            # The 8 bytes header contain the position of the content index.
            # We fill this header with zeroes and write the actual position
            # once all samples have been written to the cache
            self.file.write(struct.pack('<Q', 0))

    def __len__(self):
        return len(self.cnt.index)

    def write(self, data, meta):
        """Append raw bytes and metadata to the cache file."""
        assert self.mode == "w"

        pos = self.file.tell()
        self.file.write(data)
        entry = (pos, pos + len(data))

        # write position and metadata of the data to the content index
        self.cnt.index.append(entry)
        self.cnt.meta.append(meta)

    def read(self, index, n_samples):
        """Read n_samples at index as a list of (memoryview, meta) tuples."""
        assert self.mode == "r"

        c_ix = self.cnt.index

        # get the absolute start and end adresses of the whole chunk
        abs_start, _ = c_ix[index]
        _, abs_end = c_ix[index+n_samples-1]

        # read the bytes and wrap in memory view to avoid copying
        chunk = memoryview(self.mmfile[abs_start:abs_end])

        res = []

        for i in range(n_samples):
            start, end = c_ix[index+i]

            # convert addresses to be relative to the chunk we read
            start = start - abs_start
            end = end - abs_start

            data = chunk[start:end]
            res.append((data, self.cnt.meta[index+i]))
        return res

    def close(self):
        """Close the cache file.

        When the cache file is being written to, this method will write
        the content index at the end of the file.
        """
        if self.mode == "w":
            # Write the content index and point the header at it.
            self.cnt.write(self.file)
        self.file.close()

# The three basic serialization methods:
# raw bytes, numpy format or python pickle.
_BYTES, _NUMPY, _PICKLE = range(3)

class SerializedFileCache(FileCache):
    """Cache serialized objects in a mmapped file."""

    def __init__(self, path, mode, compress=True):
        """Create an optionally compressed serialized cache.

        :param path: path of the cache file.
        :param mode: "r" or "w".
        :param compress: when True, lz4-compress each serialized item.
        """
        super().__init__(path, mode)

        # we use info to store type information
        self.cnt.info = self.cnt.info or []
        self.compress = compress

    def _serialize_data(self, data):
        """Serialize a single object, returning a (type_, bytes) tuple."""

        # Default to raw bytes
        type_ = _BYTES

        if isinstance(data, np.ndarray):
            # When the data is a numpy array, use the more compact native
            # numpy format.
            buf = io.BytesIO()
            np.save(buf, data)
            data = buf.getvalue()
            type_ = _NUMPY

        elif not isinstance(data, (bytearray, bytes)):
            # Everything else except byte data is serialized in pickle format.
            data = pickle.dumps(data)
            type_ = _PICKLE

        if self.compress:
            # Optional compression
            data = lz4.frame.compress(data)

        return type_, data

    def _deserialize(self, data, type_):
        """Reverse _serialize_data for a single serialized item."""

        if self.compress:
            # decompress the data if needed
            data = lz4.frame.decompress(data)

        if type_ == _NUMPY:
            # deserialize numpy arrays
            buf = io.BytesIO(data)
            data = np.load(buf)

        elif type_ == _PICKLE:
            # deserialize other python objects
            data = pickle.loads(data)

        # Otherwise we just return data as it is (bytes)

        return data

    def write(self, data, meta):
        """Serialize data (a single object or an (x, y) pair) and append it."""

        if isinstance(data, tuple) and len(data) == 2:
            # write (x,y) pairs

            # serialize independent from each other
            type1, data1 = self._serialize_data(data[0])
            type2, data2 = self._serialize_data(data[1])

            pos = len(data1)
            buf = io.BytesIO()

            # an entry which consists of two items carries the position
            # of the second item in its header.
            buf.write(struct.pack('<Q', pos))
            buf.write(data1)
            buf.write(data2)

            data = buf.getvalue()

            # mark the entry as pair
            type_ = (type1, type2)
        else:
            type_, data = self._serialize_data(data)

        # remember how to deserialize this entry, then store the raw bytes
        self.cnt.info.append(type_)
        super().write(data, meta)

    def read(self, index, n_samples):
        """Read and deserialize n_samples starting at index."""

        # get the entries as raw bytes from the superclass implementation
        entries = super().read(index, n_samples)

        res = []
        for i, entry in enumerate(entries):
            data, meta = entry
            type_ = self.cnt.info[index+i]

            if isinstance(type_, tuple):
                # If the type is a pair (x,y), deserialize independently
                buf = io.BytesIO(data)

                # First, get the position of the second item from the header
                pos, = struct.unpack('<Q', buf.read(8))

                # Read the first and second item
                data1 = buf.read(pos)
                data2 = buf.read()

                # Then deserialize them independently.
                data1 = self._deserialize(data1, type_[0])
                data2 = self._deserialize(data2, type_[1])

                res.append(((data1, data2), meta))
            else:
                data = self._deserialize(data, type_)
                res.append((data, meta))

        return res