"""
Implements a custom protocol for sending and receiving
line-delimited messages. For blocking sockets, a
timeout is required to avoid DoS attacks when talking
to a misbehaving or malicious third party.

The benefit of this class is that it makes communication
with the P2P network easy to code without having to
depend on threads, and hence on mutexes (which are hard
to use correctly).

In practice, a connection to a node on the P2P network
is made using this class's default options, and the
connection is then polled periodically for replies.
Processing of replies stops as soon as the socket
indicates that it would block. To stop a malicious node
from sending replies as fast as it can, the number of
chunks read per poll is also capped.

Quirks:
* send_line will block until the entire line has been sent even if the socket
  has been set to non-blocking to make things easier. If you need a non-blocking
  way to send a line: use send(). Note that you will have to check for the
  number of bytes sent and resend if needed just like the real send function.
* connect has the same behaviour as above to make things simpler (so will block
  regardless of whether socket is in non-blocking mode or not.) If you want to
  bypass this behaviour you can always connect the socket outside this class
  and then pass it to set_socket.

Otherwise, all functions in this class behave how you would expect them to
(depending on whether you're using non-blocking mode or blocking mode.) It's
assumed that all blocking operations have a timeout by default. This can't be
disabled.

TODO: test the various functions when the connection is closed by the peer.
Non-blocking mode also needs timeouts in cases where you attempt to send all /
receive all of the data.
"""

import errno
import platform
import socket
import ssl
import sys
import time

from pyp2p.lib import get_lan_ip, parse_exception, log_exception
from pyp2p.lib import encode_str

# Passed to log_exception() (from pyp2p.lib) whenever this module catches an
# exception; presumably the file the error details are written to.
error_log_path = "error.log"


class Sock:
    """
    Line-oriented wrapper around a TCP (optionally SSL) socket.

    Incoming data is accumulated in self.buf and split on self.delimiter
    into self.replies. Instances also behave like a list of replies (see
    __len__ / __getitem__ / __iter__); each list-style access first polls
    the socket for new data.
    """

    # Timeout in seconds for connect(), which always blocks regardless of
    # the configured blocking mode.
    connect_timeout = 5

    def __init__(self, addr=None, port=None, blocking=0, timeout=5,
                 interface="default", use_ssl=0, debug=0):
        """
        :param addr: Host to connect to. The socket is only connected here
            when both addr and port are given.
        :param port: Port to connect to.
        :param blocking: Truthy for blocking mode, falsy for non-blocking.
        :param timeout: Timeout in seconds for blocking operations.
        :param interface: Local interface to bind from, or "default".
        :param use_ssl: Truthy to wrap the socket in SSL (the peer's
            certificate is not verified).
        :param debug: Truthy to print debug messages.
        :raises socket.error: If the initial connect fails.
        """
        self.nonce = None
        self.nonce_buf = u""
        self.reply_filter = None
        self.buf = b""
        self.max_buf = 1024 * 1024  # 1 MB.
        self.max_chunks = 1024  # Prevents spamming of multiple short messages.
        self.chunk_size = 1024 * 4
        self.replies = []
        self.blocking = blocking
        self.timeout = timeout
        self.use_ssl = use_ssl
        self.alive = time.time()
        self.unl = None
        self.connected = 0
        self.interface = interface
        self.delimiter = b"\r\n"
        self.debug = debug

        # Peer address, recorded by connect() / set_sock() for reconnect().
        self.addr = None
        self.port = None

        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # self.s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if self.use_ssl:
            self.s = self._wrap_ssl(self.s)

        if addr is not None and port is not None:
            # connect() applies its own bounded timeout and then restores
            # the configured blocking mode.
            self.connect(addr, port)
        else:
            self.set_blocking(self.blocking, self.timeout)

    @staticmethod
    def _wrap_ssl(sock):
        """
        Wrap sock in a client-side SSL socket without certificate checks.

        ssl.wrap_socket() no longer exists on modern Pythons, so an
        SSLContext configured like wrap_socket()'s defaults (no certificate
        or hostname verification) is used when available.

        NOTE: the connection is encrypted but the peer is NOT authenticated.
        """
        protocol = getattr(ssl, "PROTOCOL_TLS_CLIENT", None)
        if protocol is None:
            # Older Pythons lack PROTOCOL_TLS_CLIENT but still provide
            # wrap_socket().
            return ssl.wrap_socket(sock)
        context = ssl.SSLContext(protocol)
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        return context.wrap_socket(sock)

    def debug_print(self, msg):
        """Print msg prefixed with "> " when debugging is enabled."""
        if self.debug:
            print("> " + str(msg))

    def set_keep_alive(self, sock, after_idle_sec=5, interval_sec=60,
                       max_fails=5):
        """
        Enable TCP keep-alive probes on sock so dead connections can be
        detected (the TCP equivalent of the IRC ping-pong protocol).

        Linux: probing starts after after_idle_sec seconds of idleness, a
        probe is sent every interval_sec seconds, and the connection is
        dropped after max_fails unanswered probes.
        macOS: SO_KEEPALIVE is enabled and TCP_KEEPALIVE is set to
        interval_sec; after_idle_sec and max_fails are not used.
        Windows: fixed values are used (10000 ms idle, 3000 ms between
        probes); all three arguments are ignored.
        Other platforms: nothing is changed.
        """
        system = platform.system()

        if system == "Darwin":
            # Value scraped from /usr/include.
            # NOTE(review): on macOS TCP_KEEPALIVE is typically the idle time
            # before the first probe, so after_idle_sec may be the intended
            # value here -- confirm.
            TCP_KEEPALIVE = 0x10
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec)
        elif system == "Windows":
            # (on/off, idle time in ms, probe interval in ms).
            sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 10000, 3000))
        elif system == "Linux":
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE,
                            after_idle_sec)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL,
                            interval_sec)
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, max_fails)

    def set_blocking(self, blocking, timeout=5):
        """
        Switch the socket between blocking and non-blocking mode.

        In blocking mode a non-None timeout is applied with settimeout().
        The new mode and timeout are stored on the instance. Does nothing
        (not even storing them) when there is no socket.
        """
        if self.s is None:
            return

        self.s.setblocking(blocking)
        if blocking and timeout is not None:
            self.s.settimeout(timeout)

        self.timeout = timeout
        self.blocking = blocking

    def _update_timeout(self, timeout):
        """Apply a per-call timeout to a blocking socket (it persists)."""
        if timeout != self.timeout and self.blocking:
            self.set_blocking(self.blocking, timeout)

    def set_sock(self, s):
        """
        Replace the underlying socket with an existing socket s.

        The old socket is closed, and s is put into this instance's blocking
        mode and timeout. If s has a peer, its host and port are recorded
        (so reconnect() can use them) and the instance is marked connected.
        """
        self.close()  # Close old socket.
        self.s = s
        self.set_blocking(self.blocking, self.timeout)
        # self.s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        try:
            peer = self.s.getpeername()
            # Keep only (host, port); IPv6 peers return a longer tuple.
            self.addr, self.port = peer[0], peer[1]
            self.connected = 1
        except Exception:
            # Not connected (or no socket at all).
            self.connected = 0

    def reconnect(self):
        """
        Reconnect to the last known peer if currently disconnected.

        Connection errors are swallowed and leave the instance disconnected.
        Always returns None.
        """
        if self.connected:
            return None
        if self.addr is None or self.port is None:
            return None
        try:
            return self.connect(self.addr, self.port)
        except Exception:
            self.connected = 0

    # Blocking (regardless of socket mode.)
    def connect(self, addr, port):
        """
        Connect to (addr, port).

        Always blocks, for at most self.connect_timeout seconds, whatever the
        configured blocking mode. The configured mode and timeout are
        restored once connected.

        :raises socket.error: If the connection fails. The socket is then
            closed, so a later connect()/reconnect() starts from a fresh
            socket. Errors from binding to a custom interface, other than
            "address already in use", are also propagated.
        """
        # Save addr and port so socket can be reconnected.
        self.addr = addr
        self.port = port

        # No socket (e.g. after close()): create a fresh one.
        if self.s is None:
            self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            if self.use_ssl:
                self.s = self._wrap_ssl(self.s)

        # Make connection from custom interface.
        if self.interface != "default":
            try:
                # Todo: fix this to use static ips from Net
                src_ip = get_lan_ip(self.interface)
                self.s.bind((src_ip, 0))
            except socket.error as e:
                if e.errno != errno.EADDRINUSE:
                    raise

        # Bound the (always blocking) connect so it can't hang forever, even
        # when the socket was previously switched to non-blocking mode.
        self.s.settimeout(self.connect_timeout)
        try:
            self.s.connect((addr, int(port)))
            self.connected = 1
            self.set_blocking(self.blocking, self.timeout)
        except Exception as e:
            self.debug_print("Connect failed")
            error = parse_exception(e)
            self.debug_print(error)
            log_exception(error_log_path, error)
            # The state of a socket whose connect failed is unspecified, so
            # discard it rather than reuse it.
            self.close()
            raise socket.error("Socket connect failed.")

    def close(self):
        """
        Mark the connection closed and release the socket (best effort).

        Errors from shutdown()/close() are ignored; self.s is always None
        afterwards.
        """
        self.connected = 0

        sock = getattr(self, "s", None)
        if sock is not None:
            try:
                # Graceful half-close: signal the end of our stream first.
                sock.shutdown(socket.SHUT_WR)
            except Exception:
                pass
            try:
                sock.close()
            except Exception:
                pass

        self.s = None

    def parse_buf(self, encoding="unicode"):
        """
        Split complete lines out of self.buf.

        TCP is stream-oriented, so data may arrive in pieces. Every complete
        line (terminated by self.delimiter) is removed from the buffer. Empty
        lines are dropped. Any trailing partial line stays in self.buf.

        :param encoding: If "unicode", each line is converted with
            encode_str(line, "unicode"); otherwise lines are returned as bytes.
        :return: List of lines, without delimiters.
        """
        parts = self.buf.split(self.delimiter)
        if len(parts) < 2:
            # No complete line yet; leave the buffer untouched.
            return []

        replies = []
        for reply in parts[:-1]:
            if not reply:
                continue
            if encoding == "unicode":
                reply = encode_str(reply, encoding)
            replies.append(reply)

        # Everything after the last delimiter is an incomplete line.
        self.buf = parts[-1]
        return replies

    # Blocking or non-blocking.
    def get_chunks(self, fixed_limit=None, encoding="unicode"):
        """
        Read available data into self.buf without letting recv() stall the
        program or the buffer grow without bound.

        Non-blocking sockets read until the socket would block, the buffer
        limit is reached, or max_chunks reads have been made. Blocking
        sockets read one chunk at a time; without fixed_limit they sleep
        0.2 s and read again until a delimiter has arrived, the buffer is
        full, or recv() times out. With fixed_limit, a single read is made.

        The socket is closed if the peer disconnects or an unexpected socket
        error occurs.

        :param fixed_limit: If given, overrides both the buffer limit
            (self.max_buf) and the chunk limit (self.max_chunks).
        :param encoding: Unused; kept for backwards compatibility.

        http://stackoverflow.com/questions/16745409/what-does-pythons-socket-recv-return-for-non-blocking-sockets-if-no-data-is-r
        http://stackoverflow.com/questions/3187565/select-and-ssl-in-python
        """
        # Socket is disconnected.
        if not self.connected:
            return

        max_buf = self.max_buf
        max_chunks = self.max_chunks
        if fixed_limit is not None:
            max_buf = fixed_limit
            max_chunks = fixed_limit

        chunk_no = 0
        while True:
            chunk_size = self.chunk_size
            while True:
                # Don't exceed buffer size.
                buf_len = len(self.buf)
                if buf_len >= max_buf:
                    break
                chunk_size = min(chunk_size, max_buf - buf_len)

                # Don't allow non-blocking sockets to be
                # DoSed by multiple small replies.
                if chunk_no >= max_chunks and not self.blocking:
                    break

                try:
                    chunk = self.s.recv(chunk_size)
                except socket.timeout as e:
                    # Any timeout ends the read. The message text varies
                    # (SSL sockets use a different one), so it isn't checked;
                    # retrying here could otherwise loop forever.
                    self.debug_print("Get chunks timed out.")
                    self.debug_print(e)
                    return
                except ssl.SSLError as e:
                    if e.errno == ssl.SSL_ERROR_WANT_READ:
                        # Would block on a non-blocking SSL socket.
                        self.debug_print("SSL_ERROR_WANT_READ")
                        break
                    self.debug_print("Get chunks ssl error")
                    self.close()
                    return
                except socket.error as e:
                    if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                        # Would block on a non-blocking plain socket.
                        break
                    # Connection closed or other problem.
                    self.debug_print("get chunks other closing")
                    self.close()
                    return

                if chunk == b"":
                    # The peer closed the connection.
                    self.close()
                    return

                self.buf += chunk

                # Blocking sockets read one chunk per pass; another recv()
                # here would block again.
                if self.blocking:
                    break

                # Used to avoid DoS of small packets.
                chunk_no += 1

            # Only blocking reads without a fixed limit wait for a full line.
            if not self.blocking or fixed_limit is not None:
                return

            # Stop once a full line is buffered, or once the buffer is full:
            # no more data can be read then, so waiting would never end.
            if self.delimiter in self.buf or len(self.buf) >= max_buf:
                return

            time.sleep(0.2)

    def reply_callback(self, callback):
        """
        Store callback on the instance as self.reply_callback.

        NOTE(review): this assignment shadows this very method on the
        instance, and nothing in this class reads the stored value
        (update() uses self.reply_filter) -- confirm the intended use.
        """
        self.reply_callback = callback

    # Called to check for replies and update buffers.
    def update(self):
        """
        Poll the socket and move complete lines into self.replies.

        If self.reply_filter is set, every queued reply for which it returns
        a falsy value is replaced by u"" (kept in place, not removed).
        """
        self.get_chunks()
        self.replies += self.parse_buf()

        if self.reply_filter is not None:
            self.replies = [
                reply if self.reply_filter(reply) else u""
                for reply in self.replies
            ]

    # Blocking or non-blocking.
    def send(self, msg, send_all=0, timeout=5, encoding="ascii"):
        """
        Send msg over the socket in chunks of at most self.chunk_size bytes.

        :param msg: Bytes, or text, which is converted to bytes with
            encode_str(msg, "ascii"). The caller should ensure it encodes.
        :param send_all: If falsy, a single send() is attempted and the number
            of bytes it accepted is returned (0 if it would block). If truthy,
            sending continues until all of msg has gone out; while the socket
            would block it keeps retrying until timeout seconds (self.timeout
            if timeout is falsy) have passed, then returns the partial count.
        :param timeout: For blocking sockets this also becomes the socket's
            timeout (persistently).
        :param encoding: Currently unused; text is always encoded as ASCII.
        :return: Number of bytes sent. 0 for an empty msg, or on a socket
            timeout (even if part of msg was already sent). 0 on a broken
            connection or other error, and the socket is then closed.
        """
        self._update_timeout(timeout)

        try:
            # Convert to bytes Python 2 & 3.
            if isinstance(msg, type(u"")):
                msg = encode_str(msg, "ascii")

            msg_len = len(msg)
            if not msg_len:
                # send(b"") returns 0, which would be mistaken for a broken
                # connection below.
                return 0

            # Work out stop time (only used when send_all is set).
            if send_all:
                deadline = time.time() + (timeout or self.timeout)
            else:
                deadline = 0

            total_sent = 0
            while True:
                chunk = msg[total_sent:total_sent + self.chunk_size]
                try:
                    bytes_sent = self.s.send(chunk)
                except socket.timeout:
                    return 0
                except ssl.SSLError as e:
                    if e.errno not in (ssl.SSL_ERROR_WANT_WRITE,
                                       ssl.SSL_ERROR_WANT_READ):
                        self.debug_print("Con send closing other")
                        self.close()
                        return 0
                    bytes_sent = None  # Would block (non-blocking SSL).
                except socket.error as e:
                    if e.args[0] not in (errno.EAGAIN, errno.EWOULDBLOCK):
                        # Connection closed or other problem.
                        self.debug_print("Con send closing other")
                        self.close()
                        return 0
                    bytes_sent = None  # Would block (non-blocking socket).

                if bytes_sent is None:
                    # Network buffer is full: give up, or wait a little and
                    # retry until the deadline.
                    if not send_all or time.time() >= deadline:
                        return total_sent
                    time.sleep(0.001)
                    continue

                # Connection broken.
                if not bytes_sent:
                    self.close()
                    return 0

                total_sent += bytes_sent
                if not send_all or total_sent >= msg_len:
                    return total_sent

                # Avoid 100% CPU.
                time.sleep(0.001)
        except Exception as e:
            self.debug_print("Con send: " + str(e))
            error = parse_exception(e)
            log_exception(error_log_path, error)
            self.close()
            return 0

    # Blocking or non-blocking.
    def recv(self, n, encoding="unicode", timeout=5):
        """
        Read more data (capping the buffer at n bytes), then return the whole
        buffer and clear it.

        :param n: Buffer / chunk limit passed to get_chunks(); must be truthy.
        :param encoding: If "unicode", the data is converted with
            encode_str(data, "unicode"); otherwise bytes are returned.
        :param timeout: For blocking sockets this also becomes the socket's
            timeout (persistently).
        :return: The buffered data, or an empty value of the requested kind
            on error (the socket is then closed).
        """
        # Sanity checking.
        assert n

        self._update_timeout(timeout)

        try:
            self.get_chunks(n, encoding=encoding)

            # Take the current buffer and reset it.
            ret, self.buf = self.buf, b""

            if encoding == "unicode":
                ret = encode_str(ret, encoding)

            return ret
        except Exception as e:
            self.debug_print("Recv closign e" + str(e))
            error = parse_exception(e)
            log_exception(error_log_path, error)
            self.close()
            if encoding == "unicode":
                return u""
            return b""

    # Sends a new message delimitered by a new line.
    # Blocking: blocks until entire line is sent for simplicity.
    def send_line(self, msg, timeout=5):
        """
        Send msg followed by self.delimiter, retrying until all of it has
        been sent or timeout seconds have passed, even on a non-blocking
        socket.

        It's assumed that lines are small, so this won't be a bottleneck when
        the network buffer is full, and callers don't have to check the
        number of bytes sent for every line.

        :return: Number of bytes sent (see send()), or 0 when disconnected or
            on error (the socket is then closed).
        """
        # Sanity checking.
        assert (len(msg))

        # Not connected.
        if not self.connected:
            return 0

        self._update_timeout(timeout)

        try:
            # Convert to bytes Python 2 & 3
            if isinstance(msg, type(u"")):
                msg = encode_str(msg, "ascii")

            msg += self.delimiter
            return self.send(msg, send_all=1, timeout=timeout)
        except Exception as e:
            self.debug_print("Send line closing" + str(e))
            error = parse_exception(e)
            log_exception(error_log_path, error)
            self.close()
            return 0

    # Receives a new message delimited by a new line.
    # Blocking or non-blocking.
    def recv_line(self, timeout=5):
        """
        Return the next complete line (without its delimiter), or u"".

        Queued replies are returned first. Otherwise the socket is polled
        once (non-blocking) or every 2 ms until a reply arrives or timeout
        seconds pass (blocking). u"" is returned when no reply is available,
        when disconnected, or on error.

        NOTE(review): once the connection is closed, replies still queued in
        self.replies are not returned by this method.
        """
        # Socket is disconnected.
        if not self.connected:
            return u""

        self._update_timeout(timeout)

        # Return existing reply.
        if self.replies:
            return self.pop_reply()

        try:
            deadline = time.time() + (timeout or self.timeout)
            while True:
                self.update()

                # Socket is disconnected.
                if not self.connected:
                    return u""

                # Non-blocking sockets make a single pass; blocking sockets
                # stop as soon as a reply is available.
                if not self.blocking or self.replies:
                    break

                # Timeout elapsed.
                if time.time() >= deadline:
                    break

                # Avoid 100% CPU.
                time.sleep(0.002)

            if self.replies:
                return self.pop_reply()

            return u""
        except Exception as e:
            self.debug_print("recv line error")
            error = parse_exception(e)
            self.debug_print(error)
            log_exception(error_log_path, error)
            return u""

    # The methods below make the class behave like a list of the replies
    # received from the socket. Every access also checks for new replies, so
    # replies can be processed without networking boilerplate, e.g.:
    #     for reply in sock:
    #         handle(reply)

    def __len__(self):
        """Poll for new replies and return how many are queued."""
        self.update()
        return len(self.replies)

    def __getitem__(self, key):
        """Poll for new replies and return self.replies[key]."""
        self.update()
        return self.replies[key]

    def __setitem__(self, key, value):
        """Poll for new replies and set self.replies[key] = value."""
        self.update()
        self.replies[key] = value

    def __delitem__(self, key):
        """Poll for new replies and delete self.replies[key]."""
        self.update()
        del self.replies[key]

    def pop_reply(self):
        """
        Remove and return the oldest queued reply, or None if there is none.
        Does not poll the socket.
        """
        if not self.replies:
            return None
        reply = self.replies[0]
        self.replies = self.replies[1:]
        return reply

    def __iter__(self):
        """
        Poll for new replies, then return an iterator over every queued
        reply (oldest first) and clear the queue.
        """
        try:
            self.update()
        finally:
            # Clear old replies even if polling failed.
            replies, self.replies = self.replies, []
        return iter(replies)

    def __reversed__(self):
        """Like __iter__, but yields the replies newest first."""
        return reversed(list(self.__iter__()))


if __name__ == "__main__":
    # Manual smoke test, currently disabled: the example is kept inside the
    # string literal below, so running this module as a script does nothing.
    # (As written, the exit() call would also stop it right after the
    # connection is made, before any line is sent.)
    """
    s = Sock("158.69.201.105", 8540)

    exit()
    s.send_line("SOURCE TCP")


    while 1:
        for reply in s:
            print(reply)

        time.sleep(0.5)


    # print(s.recv_line())
    # print("yes")
    """