"""Ring buffers for multiprocessing.

Allows multiple child Python processes started via the multiprocessing module
to read from a shared ring buffer in the parent process. For each child, a
pointer is maintained for the purpose of reading. One pointer is maintained
for the purpose of writing. Reads may be issued in blocking or non-blocking
mode. Writes are always in non-blocking mode and will raise an exception
if the buffer is full.
"""

import ctypes
import contextlib
import functools
import multiprocessing
import struct


class Error(Exception):
    """Root of the exception hierarchy defined by this module."""


class DataTooLargeError(Error, ValueError):
    """Raised when the data being stored is larger than a slot can hold."""


class WaitingForReaderError(Error):
    """Raised by a write when a reader has not yet consumed the next slot."""


class WaitingForWriterError(Error):
    """Raised by a non-blocking read when no unread data is available yet."""


class WriterFinishedError(Error):
    """Raised by a read when nothing is left to read and no writer is active."""


class AlreadyClosedError(Error):
    """Raised by a write attempted while no writer is registered as active."""


class MustCreatedReadersBeforeWritingError(Error):
    """Raised when a reader is requested after data has already been written."""


class InternalLockingError(Error):
    """Raised when waiting for a write without the read lock being held."""


class Position:
    """Interprets an integer counter as a location in a ring of slots.

    Attributes:
        counter: Total number of steps taken so far; starts at zero.
        slot_count: Number of slots in the ring the counter wraps around.
    """

    def __init__(self, slot_count):
        self.slot_count = slot_count
        self.counter = 0

    @property
    def index(self):
        """Offset of the current slot within the ring."""
        _, offset = divmod(self.counter, self.slot_count)
        return offset

    @property
    def generation(self):
        """Number of complete trips around the ring made so far."""
        laps, _ = divmod(self.counter, self.slot_count)
        return laps


class Pointer:
    """Counter held in a multiprocessing RawValue, viewable as a Position.

    Attributes:
        counter: RawValue (c_longlong) holding the current counter.
        position: A single Position instance that get() refreshes and returns.
    """

    def __init__(self, slot_count, *, start=None):
        """Initializer.

        Args:
            slot_count: Number of slots in the ring this pointer walks.
            start: Initial counter value; None means start from zero.
        """
        initial = 0 if start is None else start
        self.counter = multiprocessing.RawValue(ctypes.c_longlong, initial)
        self.position = Position(slot_count)

    def increment(self):
        """Advances the counter by one."""
        self.counter.value = self.counter.value + 1

    def get(self):
        """Returns the Position refreshed with the current counter value.

        The same Position object is handed back on every call (it is
        updated in place rather than reallocated), so callers must not
        hold on to it across calls expecting a stable snapshot.
        """
        current = self.position
        current.counter = self.counter.value
        return current

    def set(self, counter):
        """Overwrites the counter with the given value."""
        self.counter.value = counter


class RingBuffer:
    """Circular buffer class accessible to multiple threads or child processes.

    All methods are thread safe. Multiple readers and writers are permitted.
    Before kicking off multiprocessing.Process instances, first allocate all
    of the writers you'll need with new_writer() and readers with new_reader().
    Pass the Pointer value returned by the new_reader() method to the
    multiprocessing.Process constructor along with the RingBuffer instance.
    Calling new_writer() or new_reader() from a child multiprocessing.Process
    will not work.
    """

    def __init__(self, *, slot_bytes, slot_count):
        """Initializer.

        Args:
            slot_bytes: The maximum size of slots in the buffer.
            slot_count: How many slots should be in the buffer.
        """
        self.slot_count = slot_count
        self.array = SlotArray(slot_bytes=slot_bytes, slot_count=slot_count)
        self.lock = ReadersWriterLock()
        # Each reading process may modify its own Pointer while the read
        # lock is being held. Each reading process can also load the position
        # of the writer, but not load any other readers. Each reading process
        # can also load the value of the 'active' count.
        self.readers = []
        # The writer can load and store the Pointer of all the reader Pointers
        # or the writer Pointer while the write lock is held. It can also load
        # and store the value of the 'active' count.
        self.writer = Pointer(self.slot_count)
        # Number of writers that called new_writer() and have not yet called
        # writer_done().
        self.active = multiprocessing.RawValue(ctypes.c_uint, 0)

    def new_reader(self):
        """Returns a new unique reader into the buffer.

        This must only be called in the parent process. It must not be
        called in a child multiprocessing.Process. See class docstring. To
        enforce this policy, no readers may be allocated after the first
        write has occurred.

        Returns:
            The Pointer to pass to try_read() / blocking_read().

        Raises:
            MustCreatedReadersBeforeWritingError: If data has already been
                written into the ring.
        """
        with self.lock.for_write():
            writer_position = self.writer.get()
            if writer_position.counter > 0:
                raise MustCreatedReadersBeforeWritingError

            reader = Pointer(self.slot_count, start=writer_position.counter)
            self.readers.append(reader)
            return reader

    def new_writer(self):
        """Must be called once by each writer before any reads occur.

        Should be paired with a single subsequent call to writer_done() to
        indicate that this writer has finished and will not write any more
        data into the ring.
        """
        with self.lock.for_write():
            self.active.value += 1

    def _has_write_conflict(self, position):
        """Returns True if writing at position would clobber unread data.

        Must be called while the write lock is held.
        """
        index = position.index
        generation = position.generation
        for reader in self.readers:
            # This Position and the other Position both point at the same index
            # in the ring buffer, but they have different generation numbers.
            # This means the writer can't proceed until some readers have
            # sufficiently caught up.
            reader_position = reader.get()
            if (reader_position.index == index and
                    reader_position.generation < generation):
                return True

        return False

    def try_write(self, data):
        """Tries to write the next slot, but will not block.

        Once a successful write occurs, all pending blocking_read() calls
        will be woken up to consume the newly written slot.

        Args:
            data: Bytes to write in the next available slot. Must be
                less than or equal to slot_bytes in size.

        Raises:
            AlreadyClosedError: If no writer is currently active, i.e.
                new_writer() was never called or every writer has already
                called writer_done().
            WaitingForReaderError: If all of the slots are full and we need
                to wait for readers to catch up before there will be
                sufficient room to write more data. This is a sign that
                the readers can't keep up with the writer. Consider calling
                force_reader_sync() if you need to force the readers to
                catch up, but beware that means they will miss data.
            DataTooLargeError: If data is larger than slot_bytes.
        """
        with self.lock.for_write():
            if not self.active.value:
                raise AlreadyClosedError

            position = self.writer.get()
            if self._has_write_conflict(position):
                raise WaitingForReaderError

            # Only advance the writer once the slot has been stored, so a
            # failed store (e.g. DataTooLargeError) leaves the ring unchanged.
            self.array[position.index] = data
            self.writer.increment()

    def _has_read_conflict(self, reader_position):
        """Returns True if the reader has consumed everything written so far."""
        writer_position = self.writer.get()
        return writer_position.counter <= reader_position.counter

    def _try_read_no_lock(self, reader):
        """Reads the next slot for reader; the caller must hold the read lock."""
        position = reader.get()
        if self._has_read_conflict(position):
            if not self.active.value:
                raise WriterFinishedError
            else:
                raise WaitingForWriterError

        data = self.array[position.index]
        reader.increment()
        return data

    def try_read(self, reader):
        """Tries to read the next slot for a reader, but will not block.

        Args:
            reader: Pointer previously returned by the call to new_reader().

        Returns:
            bytearray containing a copy of the data from the slot. This
            value is mutable and can be used to back ctypes objects, NumPy
            arrays, etc.

        Raises:
            WriterFinishedError: If the RingBuffer was closed before this
                read operation began.
            WaitingForWriterError: If the given reader has already consumed
                all the data in the ring buffer and would need to block in
                order to wait for new data to arrive.
        """
        with self.lock.for_read():
            return self._try_read_no_lock(reader)

    def blocking_read(self, reader):
        """Reads the next slot for a reader, blocking if it isn't filled yet.

        Args:
            reader: Pointer previously returned by the call to new_reader().

        Returns:
            bytearray containing a copy of the data from the slot. This
            value is mutable and can be used to back ctypes objects, NumPy
            arrays, etc.

        Raises:
            WriterFinishedError: If the RingBuffer was closed while waiting
                to read the next operation.
        """
        with self.lock.for_read():
            while True:
                try:
                    return self._try_read_no_lock(reader)
                except WaitingForWriterError:
                    # Temporarily gives up the read lock until a writer
                    # releases the write lock, then retries the read.
                    self.lock.wait_for_write()

    def force_reader_sync(self):
        """Forces all readers to skip to the position of the writer."""
        with self.lock.for_write():
            writer_position = self.writer.get()

            for reader in self.readers:
                reader.set(writer_position.counter)

    def writer_done(self):
        """Called by the writer when no more data is expected to be written.

        Should be called once for every corresponding call to new_writer().
        Once all writers have called writer_done(), a WriterFinishedError
        exception will be raised by any blocking read calls or subsequent
        calls to read.

        Raises:
            AlreadyClosedError: If there is no active writer left to finish
                (writer_done() was called more often than new_writer()).
        """
        with self.lock.for_write():
            # 'active' is an unsigned C integer: decrementing it at zero would
            # silently wrap around to a huge value and make the ring look
            # permanently open, so readers would never see the end of data.
            if not self.active.value:
                raise AlreadyClosedError

            self.active.value -= 1


class SlotArray:
    """Indexable array of fixed-capacity byte slots in shared memory.

    Every slot starts with a 4-byte big-endian length header followed by
    room for up to slot_bytes of payload. No locking is done here; callers
    are expected to synchronize access themselves.
    """

    def __init__(self, *, slot_bytes, slot_count):
        """Initializer.

        Args:
            slot_bytes: Payload capacity of each slot, in bytes.
            slot_count: Number of slots to allocate.
        """
        self.length_bytes = 4
        self.slot_bytes = slot_bytes
        self.slot_count = slot_count
        # One slot is the length header plus the payload area.
        slot_type = ctypes.c_byte * (self.length_bytes + slot_bytes)
        self.array = multiprocessing.RawArray(slot_type, slot_count)

    def __getitem__(self, i):
        """Returns a mutable copy of the payload stored in slot i."""
        slot = memoryview(self.array[i])
        payload_size = struct.unpack_from('>I', slot, 0)[0]

        # The payload is copied out on purpose: the writer may reuse the
        # slot once every reader has moved past it, and a bytearray lets the
        # caller back a ctypes record with from_buffer() rather than
        # from_buffer_copy().
        payload_start = self.length_bytes
        payload_end = payload_start + payload_size
        return bytearray(slot[payload_start:payload_end])

    def __setitem__(self, i, data):
        """Stores data into slot i.

        Raises:
            DataTooLargeError: If data is longer than slot_bytes.
        """
        payload = memoryview(data).cast('@B')
        payload_size = len(payload)
        if payload_size > self.slot_bytes:
            raise DataTooLargeError('%d bytes too big for slot' % payload_size)

        # The input is only viewed, never copied, until the single copy into
        # the slot below.
        slot = memoryview(self.array[i]).cast('@B')
        struct.pack_into('>I', slot, 0, payload_size)
        header_size = self.length_bytes
        slot[header_size:header_size + payload_size] = payload

    def __len__(self):
        """Returns the number of slots."""
        return self.slot_count


class ReadersWriterLock:
    """Multiprocessing-compatible Readers/Writer lock.

    The algorithm:
    https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Using_a_condition_variable_and_a_mutex

    Background on the Kernel:
    https://www.kernel.org/doc/Documentation/memory-barriers.txt

    sem_wait on Linux uses NPTL, which uses futexes:
    https://github.com/torvalds/linux/blob/master/kernel/futex.c

    Notably, futexes use the smp_mb() memory fence, which is a general write
    barrier, meaning we can assume that all memory reads and writes before
    a barrier will complete before reads and writes after the barrier, even
    if the semaphore / futex isn't actively held.
    """

    def __init__(self):
        # A single mutex guards both counters below and backs both
        # condition variables.
        self.lock = multiprocessing.Lock()
        self.readers_condition = multiprocessing.Condition(self.lock)
        self.writer_condition = multiprocessing.Condition(self.lock)
        # Number of readers currently holding the lock.
        self.readers = multiprocessing.RawValue(ctypes.c_uint, 0)
        # True while a writer holds the lock.
        self.writer = multiprocessing.RawValue(ctypes.c_bool, False)

    def _acquire_reader_lock(self):
        """Blocks until no writer holds the lock, then registers a reader."""
        with self.lock:
            while self.writer.value:
                self.readers_condition.wait()

            self.readers.value += 1

    def _release_reader_lock(self):
        """Unregisters a reader; the last one out wakes a waiting writer."""
        with self.lock:
            self.readers.value -= 1

            if self.readers.value == 0:
                self.writer_condition.notify()

    @contextlib.contextmanager
    def for_read(self):
        """Acquire the lock for reading."""
        self._acquire_reader_lock()
        try:
            yield
        finally:
            self._release_reader_lock()

    def _acquire_writer_lock(self):
        """Blocks until there is no writer and no reader, then takes the lock."""
        with self.lock:
            while self.writer.value or self.readers.value > 0:
                self.writer_condition.wait()

            self.writer.value = True

    def _release_writer_lock(self):
        """Gives up the write lock and wakes all readers and one writer."""
        with self.lock:
            self.writer.value = False
            self.readers_condition.notify_all()
            self.writer_condition.notify()

    @contextlib.contextmanager
    def for_write(self):
        """Acquire the lock for writing."""
        self._acquire_writer_lock()
        try:
            yield
        finally:
            self._release_writer_lock()

    def wait_for_write(self):
        """Block until a writer has notified readers.

        Must be called while the read lock is already held. The read lock is
        given up while waiting and is held again when this returns. May
        return spuriously before the writer actually did something.

        Raises:
            InternalLockingError: If no reader currently holds the lock.
        """
        with self.lock:
            if self.readers.value == 0:
                raise InternalLockingError
            # Step out as a reader so that a writer is able to get in.
            self.readers.value -= 1
            self.writer_condition.notify()
            self.readers_condition.wait()
            # Between being notified and re-acquiring the mutex, a writer may
            # have taken the write lock. Re-registering as a reader in that
            # state would let a reader and a writer run at the same time, so
            # keep waiting until no writer holds the lock.
            while self.writer.value:
                self.readers_condition.wait()
            self.readers.value += 1