# ==================================================================================================
# Copyright 2014 Twitter, Inc.
# --------------------------------------------------------------------------------------------------
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this work except in compliance with the License.
# You may obtain a copy of the License in the LICENSE file, or at:
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==================================================================================================


from collections import defaultdict
from random import random
from threading import Thread

import logging
import os
import hexdump
import signal
import socket
import struct
import sys

from .client_message import ClientMessage, Request
from .network import BadPacket, get_ip, get_ip_packet, SnifferBase
from .server_message import Reply, ServerMessage, WatchEvent
from .zookeeper import DeserializationError, OpCodes
from .util import StringTooLong, to_bytes

from scapy.config import conf as scapy_conf
scapy_conf.logLevel = logging.ERROR  # shush scapy

from scapy.sendrecv import sniff
from six.moves import intern
from twitter.common import log


DEFAULT_PORT = 2181

# ZooKeeper "four letter word" admin commands. A client payload starting with one of these
# is not a jute-encoded request: the sniffer skips it, and it skips the server's reply too.
FOUR_LETTER_WORDS = tuple(to_bytes(word) for word in (
  'conf',
  'cons',
  'crst',
  'dirs',
  'dump',
  'envi',
  'gtmk',
  'isro',
  'mntr',
  'ruok',
  'srst',
  'srvr',
  'stat',
  'stmk',
  'wchc',
  'wchp',
  'wchs',
  'kill',  # deprecated
  'reqs',  # deprecated
))


class SnifferConfig(object):
  """Settings for a :class:`Sniffer`: interface, capture filter and message filtering."""

  def __init__(self,
      iface="eth0",
      writes_only=False,
      debug=False):
    """
      :param iface: network interface to sniff on ("lo" and "lo0" are treated as loopback)
      :param writes_only: if True, only write requests are dispatched to request handlers
      :param debug: debug flag, stored as-is and reported by __str__

      client_port starts at 0, which means we sniff all clients.
      If zookeeper_port, included_ips or excluded_ips are changed later on you must
      call update_filter(). Ping requests are excluded by default (see include_pings()).
    """
    self.iface = iface
    self.writes_only = writes_only
    self.debug = debug
    self.client_port = 0              # 0 means: sniff all clients
    self.track_replies = False        # record request xid -> opcode per client
    self.max_queued_requests = 10000  # per-client cap on recorded requests
    self.zookeeper_port = DEFAULT_PORT
    self.excluded_opcodes = set()
    # computed once from the initial iface; not refreshed if iface is reassigned later
    self.is_loopback = iface in ["lo", "lo0"]
    self.read_timeout_ms = 0
    self.dump_bad_packet = False      # hexdump packets that fail to decode
    self.sampling = 1.0  # fraction of packets to inspect, in [0, 1]

    # These are set after initialization, and require `update_filter` to be called
    self.included_ips = []
    self.excluded_ips = []

    self.update_filter()
    self.exclude_pings()

  def update_filter(self):
    """
    Rebuild ``self.filter`` (a capture filter string) from zookeeper_port and the IP lists.

    :raises AssertionError: if both included_ips and excluded_ips are set; in that
      case ``self.filter`` is left untouched.
    """
    # An explicit raise, unlike `assert`, is not stripped under `python -O`. AssertionError
    # is kept so callers see the same exception type as before. Validate before mutating
    # so a bad config never leaves a half-built filter behind.
    if self.included_ips and self.excluded_ips:
      raise AssertionError("included_ips and excluded_ips are mutually exclusive")

    port_filter = "port %d" % (self.zookeeper_port)

    if self.excluded_ips:
      self.filter = port_filter + " and host not " + " and host not ".join(self.excluded_ips)
    elif self.included_ips:
      self.filter = port_filter + " and (host " + " or host ".join(self.included_ips) + ")"
    else:
      self.filter = port_filter

  def include_pings(self):
    """Stop excluding ping requests."""
    self.update_exclusion_list(OpCodes.PING, False)

  def exclude_pings(self):
    """Exclude ping requests (the default)."""
    self.update_exclusion_list(OpCodes.PING, True)

  def excluded(self, opcode):
    """Return True if messages with this opcode should be ignored."""
    return opcode in self.excluded_opcodes

  def update_exclusion_list(self, opcode, exclude):
    """Add opcode to the exclusion set if exclude is truthy, else remove it (no-op if absent)."""
    if exclude:
      self.excluded_opcodes.add(opcode)
    else:
      self.excluded_opcodes.discard(opcode)

  def __str__(self):
    return """
***sniffer config ***
iface = %s
writes_only = %s
filter = %s
zookeeper_port = %d
is_loopback = %s
read_timeout_ms = %d
debug = %s
""" % (self.iface,
          str(self.writes_only).lower(),
          self.filter,
          self.zookeeper_port,
          str(self.is_loopback),
          self.read_timeout_ms,
          str(self.debug).lower())


class Sniffer(SnifferBase):
  """
  Sniffs ZooKeeper traffic with scapy and dispatches decoded messages to handlers.
  """

  class RegistrationError(Exception):
    """Raised when the same handler is registered twice in one handler list."""

  def __init__(self,
               config,
               request_handler=None,
               reply_handler=None,
               event_handler=None,
               error_to_stderr=False):
    """
    This sniffer will intercept:
     - client requests
     - server replies
     - server events (i.e.: connection state change or, most of the times, watches)
    Hence handlers for each.

    :param config: the :class:`SnifferConfig` to sniff with
    :param request_handler: optional callable invoked with each decoded ``Request``
    :param reply_handler: optional callable invoked with each decoded ``Reply``
    :param event_handler: optional callable invoked with each decoded ``WatchEvent``
    :param error_to_stderr: report errors on stderr instead of through ``log``
    """
    super(Sniffer, self).__init__()

    self._error_to_stderr = error_to_stderr
    self._packet_size = 65535
    self._request_handlers = []
    self._reply_handlers = []
    self._event_handlers = []
    # if tracking replies, keep a tab for seen reqs: client -> {xid: opcode}
    self._requests_xids = defaultdict(dict)
    # key: client addr, val: the four letter word it sent (its reply must be skipped)
    self._four_letter_mode = {}
    self._wants_stop = False

    self.config = config

    self.add_request_handler(request_handler)
    self.add_reply_handler(reply_handler)
    self.add_event_handler(event_handler)

    # Thread.setDaemon() is deprecated since Python 3.10; the `daemon` attribute is the
    # supported spelling (and also works on Python 2.6+).
    self.daemon = True

  def stop(self):
    """Ask the sniff loop to exit; scapy checks wants_stop() as each packet arrives."""
    self._wants_stop = True

  def add_request_handler(self, handler):
    """Register a callable for client requests (None is ignored)."""
    self._add_handler(self._request_handlers, handler)

  def add_reply_handler(self, handler):
    """Register a callable for server replies (None is ignored)."""
    self._add_handler(self._reply_handlers, handler)

  def add_event_handler(self, handler):
    """Register a callable for server watch events (None is ignored)."""
    self._add_handler(self._event_handlers, handler)

  def _add_handler(self, handlers, handler):
    """
    Append handler to handlers.

    :raises RegistrationError: if handler is already in handlers
    """
    if handler is None:
      return

    if handler in handlers:
      raise self.RegistrationError("handler %s has already been added" % (handler))

    handlers.append(handler)

  def wants_stop(self, *args, **kwargs):  # pragma: no cover
    """scapy ``stop_filter`` callback: True once stop() has been called."""
    return self._wants_stop

  def run(self):
    """
    Run scapy's sniff loop until stop() is requested or capturing fails.

    Socket errors are reported (stderr or log) rather than raised. However the loop
    ends, SIGINT is sent to this very process so the rest of the program can notice.
    """
    try:
      log.info("Setting filter: %s", self.config.filter)
      sniff_kwargs = dict(
        filter=self.config.filter,
        store=0,
        prn=self.handle_packet,
        stop_filter=self.wants_stop,
      )
      if self.config.iface != "any":
        # for "any", iface is omitted and the interface choice is left to scapy
        sniff_kwargs["iface"] = self.config.iface
      sniff(**sniff_kwargs)
    except socket.error as ex:
      self._report_error("Error: %s, device: %s", ex, self.config.iface)
    finally:
      log.info("The sniff loop exited")
      os.kill(os.getpid(), signal.SIGINT)

  def _report_error(self, fmt, *args):
    """Report an error on stderr or through ``log``, depending on error_to_stderr."""
    if self._error_to_stderr:
      sys.stderr.write((fmt % args) + "\n")
    else:
      log.error(fmt, *args)

  def handle_packet(self, packet):
    """
    scapy ``prn`` callback: decode one captured packet and dispatch it.

    Only a ``config.sampling`` fraction of packets is inspected. Packets that cannot be
    decoded (or that are filtered out via BadPacket) are dropped; with
    ``config.dump_bad_packet`` the error and a hexdump of the payload are printed.
    """
    sampling = self.config.sampling
    if sampling < 1.0 and random() > sampling:
      return

    try:
      message = self.message_from_packet(packet)
      self.handle_message(message)
    except (BadPacket, StringTooLong, DeserializationError, struct.error) as ex:
      if self.config.dump_bad_packet:
        print("got: %s" % str(ex))
        hexdump.hexdump(packet.load)
        sys.stdout.flush()

  def handle_message(self, message):
    """Invoke the handlers for message, unless it is None or its opcode is excluded."""
    if message and not self.config.excluded(message.opcode):
      for h in self._handlers_for(message):
        h(message)

  def _handlers_for(self, message):
    """
    Return the handler list matching message's type.

    :raises BadPacket: for a non-write request when ``config.writes_only`` is set, or
      for a message type with no handler list
    """
    if isinstance(message, Request):
      if self.config.writes_only and not message.is_write:
        raise BadPacket("Not a write packet")
      return self._request_handlers
    elif isinstance(message, Reply):
      return self._reply_handlers
    elif isinstance(message, WatchEvent):
      return self._event_handlers

    raise BadPacket("No handlers for: %s" % (message))

  def message_from_packet(self, packet):
    """
    :returns: Returns an instance of ClientMessage or ServerMessage (or a subclass),
      or None when the packet carries no payload
    :raises:
      :exc:`BadPacket` if the packet is for a client we are not tracking, is not to or
        from the ZooKeeper port, or is a four letter word request/response
      :exc:`DeserializationError` if deserialization failed
      :exc:`struct.error` if deserialization failed
    """
    client_port = self.config.client_port
    zk_port = self.config.zookeeper_port
    ip_p = get_ip_packet(packet.load, client_port, zk_port, self.config.is_loopback)
    segment = ip_p.data  # transport-layer segment; exposes sport/dport/data

    data = segment.data
    if 0 == len(data):
      return None

    to_server = segment.dport == zk_port
    if not to_server and segment.sport != zk_port:
      raise BadPacket("Packet to the wrong port?")

    src = intern("%s:%s" % (get_ip(ip_p, ip_p.src), segment.sport))
    dst = intern("%s:%s" % (get_ip(ip_p, ip_p.dst), segment.dport))

    if to_server:
      client, server = src, dst
      if data.startswith(FOUR_LETTER_WORDS):
        # Not a jute request: remember it so the server's reply is skipped as well.
        self._set_four_letter_mode(client, data[0:4])
        raise BadPacket("Four letter request %s" % data[0:4])
      client_message = ClientMessage.from_payload(data, client, server)
      client_message.timestamp = packet.time
      self._track_client_message(client_message)
      return client_message

    client, server = dst, src
    four_letter = self._get_four_letter_mode(client)
    if four_letter:
      self._set_four_letter_mode(client, None)
      raise BadPacket("Four letter response %s" % four_letter)
    # .get() rather than [] so a reply never creates an entry in the defaultdict
    requests_xids = self._requests_xids.get(client, {})
    server_message = ServerMessage.from_payload(data, client, server, requests_xids)
    server_message.timestamp = packet.time
    return server_message

  def _track_client_message(self, request):
    """
    Any request that is not a ping or a close should be tracked (when
    ``config.track_replies`` is set), recording its xid -> opcode for its client.

    At most ``config.max_queued_requests`` requests are kept per client; further ones
    are dropped and reported.
    """
    if self.config.track_replies and not request.is_ping and not request.is_close:
      requests_xids = self._requests_xids[request.client]
      # Checked before inserting, so `>=` caps the queue at exactly max_queued_requests.
      if len(requests_xids) >= self.config.max_queued_requests:  # pragma: no cover
        # TODO: logging the counts of each type of pkts in the queue when this happens
        #       could be useful.
        self._report_error("Too many queued requests, replies for %s will be lost",
                           request.client)
        return

      requests_xids[request.xid] = request.opcode

  def _get_four_letter_mode(self, client):
    """Return the four letter word client is waiting on a reply for, or None."""
    return self._four_letter_mode.get(client)

  def _set_four_letter_mode(self, client, four_letter):
    """Remember the four letter word sent by client, or forget it when four_letter is falsy."""
    if four_letter:
      self._four_letter_mode[client] = four_letter
    else:
      # pop() instead of del: clearing an untracked client is a no-op, not a KeyError
      self._four_letter_mode.pop(client, None)