import socket
import threading
import logging
import platform
import traceflow
import struct


class socket_handler:
    """Sender of pre-built IPv4 packets to one destination over a raw socket.

    Creating an instance opens an ``AF_INET``/``SOCK_RAW``/``IPPROTO_RAW``
    socket, which normally requires root privileges. Mac OS and Windows are
    rejected at construction time.
    """

    def __init__(self, daddr):
        """
        Open the raw socket and apply per-platform socket options.

        :param daddr: destination address; used as the host part of the
            address given to ``sendto()`` in ``send_ipv4``
        :raises SystemExit: (exit code 1) when the raw socket cannot be opened
            for lack of privileges, or when running on Mac OS or Windows
        """
        try:
            self.raw_sock = socket.socket(
                socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW
            )
        except PermissionError as e:
            print(e)
            print("Please run as root!")
            # SystemExit rather than the site-provided exit() helper, which is
            # meant for interactive use and is absent under ``python -S``.
            raise SystemExit(1)
        self.ip_daddr = daddr
        os_release = platform.system()
        if os_release == "Darwin":
            # Boned. TODO: Work on fixing this.
            print("Detected Mac OS - Cannot support writing of raw IP packets, exiting")
            # A caller may catch SystemExit, so do not leave the socket open.
            self.raw_sock.close()
            raise SystemExit(1)
        if os_release == "Windows":
            # No idea - No ability to test. Maybe abort?
            # TODO: Find testers?
            logging.debug("Detected NT")
            print("Untested on Windows - Exiting")
            self.raw_sock.close()
            raise SystemExit(1)
        if os_release.endswith("BSD"):
            # BSD - Need to explicitly set IP_HDRINCL.
            # BSD - Need to explicitly calculate IP total length
            self.raw_sock.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
            logging.debug("Detected a BSD")
        elif os_release == "Linux":
            # Linux - No need to set IP_HDRINCL, as setting SOCK_RAW auto sets
            # this. However should be explicit in settings.
            self.raw_sock.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
            logging.debug("Detected Linux")

    def send_ipv4(self, packet: bytes) -> int:
        """
        send_ipv4 is a thin wrapper around socket.sendto()

        :param packet: bytes object containing an IPv4 header, and encap'd proto packet (ie: udp/tcp)
        :return: int: the number of bytes handed to the socket, as reported by sendto()
        """
        # The port is written as 0 here; the packet itself already carries
        # whatever transport header the caller built.
        return self.raw_sock.sendto(packet, (self.ip_daddr, 0))

    @staticmethod
    def get_egress_ip(daddr: str) -> str:
        """
        get_egress_ip finds out our egress address for a given destination
        TODO: Fixup for IPv6 compat.

        :param daddr: destination address, used as the host part of the
            AF_INET address passed to connect()
        :return: egress_ip_address, a string/quad dotted notation IPv4 address
        :raises OSError: if the probe socket cannot be connected to daddr
        """
        # The context manager closes the probe socket even when connect() or
        # getsockname() raises, so no descriptor is leaked on failure.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect((daddr, 1))  # connect() for UDP doesn't send packets
            egress_ip_address = s.getsockname()[0]
        logging.debug("Picked %s as egress IP", egress_ip_address)
        return egress_ip_address


class socket_listener:
    """Background collector of ICMP replies received on a raw socket.

    A daemon thread runs ``listener`` and stores selected replies in
    ``icmp_packets``, a dict keyed by the ip_id of the IPv4 header quoted
    inside the ICMP message. Every access to that dict goes through
    ``mutex``.
    """

    def __init__(self, ip_daddr):
        """
        Open the raw ICMP socket and start the listener thread.

        :param ip_daddr: address of the traced destination; any decodable
            packet received from this address is stored, whatever its ICMP type
        :raises SystemExit: (exit code 1) when the raw socket cannot be opened
            for lack of privileges
        """
        # We're only interested in ICMP, so happy to have this hard coded.
        try:
            self.icmp_listener = socket.socket(
                socket.AF_INET, socket.SOCK_RAW, socket.getprotobyname("icmp")
            )
        except PermissionError as e:
            print(e)
            print("Please run as root!")
            # SystemExit rather than the site-provided exit() helper, which is
            # meant for interactive use and is absent under ``python -S``.
            raise SystemExit(1)
        # TODO: Test Timestamps correctly
        try:
            # NOTE(review): 35 is hard-coded here as SO_TIMESTAMPNS; presumably
            # the Linux value - confirm before relying on it elsewhere.
            SO_TIMESTAMPNS = 35
            self.icmp_listener.setsockopt(socket.SOL_SOCKET, SO_TIMESTAMPNS, 1)
        except OSError:
            logging.debug("Timestamps not available, continuing without them for now")
        self.ip_daddr = ip_daddr
        self.mutex = threading.Lock()
        logging.debug("Starting")
        self.icmp_packets = dict()
        # daemon=True replaces the deprecated Thread.setDaemon(True).
        t = threading.Thread(target=self.listener, daemon=True)
        t.start()

    def listener(self):
        """thread worker function

        Loops forever: receives one packet, decodes it, and stores the decoded
        outer IPv4 packet under the quoted ip_id when the message is a TTL
        Expired (ICMP type 11) or was sent by ``ip_daddr``.
        """
        logging.debug("Listening for ICMP...")
        while True:
            icmp_packet, curr_addr = self.icmp_listener.recvfrom(512)
            try:
                # Decode the IPv4 packet around the ICMP message
                i = traceflow.packet_decode.decode_ipv4_header(icmp_packet)
                # Decode the actual ICMP message inside the IPv4 packet
                icmp_packet_ret = traceflow.packet_decode.decode_icmp(i["payload"])
                icmp_type = icmp_packet_ret["type"]
                # And decode the returning IPv4 packet which is the payload inside of the ICMP message
                ip_id = traceflow.packet_decode.decode_ipv4_header(
                    icmp_packet_ret["payload"]
                )["ip_id"]
            except Exception:
                # Thread boundary: a packet that cannot be decoded must not
                # kill the listener, or every later reply would be lost.
                logging.debug(
                    "Ignoring undecodable ICMP packet from %s",
                    curr_addr[0],
                    exc_info=True,
                )
                continue
            # Did we get a TTL Expired (11)?
            ttl_expired = icmp_type == 11
            if ttl_expired:
                logging.debug(
                    "Got TTL Expired from %s with ip_id %s", curr_addr[0], ip_id
                )
            ## TODO: Correctly implement a "stop" here
            if ttl_expired or curr_addr[0] == self.ip_daddr:
                with self.mutex:
                    self.icmp_packets[ip_id] = i

    def get_packet_by_ipid(self, ipid: int) -> dict:
        """
        get_packet_by_ipid will take in a specific ip.id and find the corresponding packet

        :param ipid: the ip.id of the packet in question
        :return: dict() with the decoded ICMP message of the stored packet, or
            None when nothing has been captured for that ip.id
        """
        # icmp_packets is keyed by the quoted ip_id, so a direct lookup finds
        # the same entry a scan would. The lock is held only for the lookup
        # and is always released, including when a match is found.
        with self.mutex:
            stored = self.icmp_packets.get(ipid)
        if stored is None:
            return None
        return traceflow.packet_decode.decode_icmp(stored["payload"])

    def get_all_packets(self) -> dict:
        """
        get_all_packets returns all currently captured packets.

        :return: dict() mapping ip_id to the captured packet dict; a shallow
            copy taken under the lock, so the listener thread cannot resize it
            while the caller iterates
        """
        with self.mutex:
            return dict(self.icmp_packets)

    def get_packets_by_pathid(self, path_id: int) -> list:
        """
        get_packets_by_pathid depends on the fact that we intend to manually construct the IPID for each packet, so the
        top 8 bits correspond to a "run".

        :param path_id: an int which is 8 bits in size and corresponds to a path.
        :return: list of packets
        """
        # The dict key is the 16-bit ip_id of the quoted header, so its high
        # byte can be read directly instead of re-decoding every stored packet.
        with self.mutex:
            return [
                packet
                for ip_id, packet in self.icmp_packets.items()
                if ip_id >> 8 == path_id
            ]