import itertools
import json
import logging
import pickle
# noinspection PyPep8Naming
from typing import Iterator, KeysView, List, Optional, Text

from rasa.core.actions.action import ACTION_LISTEN_NAME
from rasa.core.broker import EventChannel
from rasa.core.domain import Domain
from rasa.core.trackers import (
    ActionExecuted, DialogueStateTracker, EventVerbosity)
from rasa.core.utils import class_from_module_path

logger = logging.getLogger(__name__)


class TrackerStore(object):
    def __init__(self,
                 domain: Optional[Domain],
                 event_broker: Optional[EventChannel] = None) -> None:
        self.domain = domain
        self.event_broker = event_broker
        self.max_event_history = None

    @staticmethod
    def find_tracker_store(domain, store=None, event_broker=None):
        if store is None or store.type is None:
            return InMemoryTrackerStore(domain, event_broker=event_broker)
        elif store.type == 'redis':
            return RedisTrackerStore(domain=domain,
                                     host=store.url,
                                     event_broker=event_broker,
                                     **store.kwargs)
        elif store.type == 'mongod':
            return MongoTrackerStore(domain=domain,
                                     host=store.url,
                                     event_broker=event_broker,
                                     **store.kwargs)
        elif store.type.lower() == 'sql':
            return SQLTrackerStore(domain=domain,
                                   url=store.url,
                                   event_broker=event_broker,
                                   **store.kwargs)
        else:
            return TrackerStore.load_tracker_from_module_string(domain, store)

    @staticmethod
    def load_tracker_from_module_string(domain, store):
        custom_tracker = None
        try:
            custom_tracker = class_from_module_path(store.type)
        except (AttributeError, ImportError):
            logger.warning("Store type '{}' not found. "
                           "Using InMemoryTrackerStore instead"
                           .format(store.type))

        if custom_tracker:
            return custom_tracker(domain=domain,
                                  url=store.url, **store.kwargs)
        else:
            return InMemoryTrackerStore(domain)

    def get_or_create_tracker(self, sender_id, max_event_history=None):
        tracker = self.retrieve(sender_id)
        self.max_event_history = max_event_history
        if tracker is None:
            tracker = self.create_tracker(sender_id)
        return tracker

    def init_tracker(self, sender_id):
        if self.domain:
            return DialogueStateTracker(
                sender_id,
                self.domain.slots,
                max_event_history=self.max_event_history)
        else:
            return None

    def create_tracker(self, sender_id, append_action_listen=True):
        """Creates a new tracker for the sender_id.

        The tracker is initially listening."""

        tracker = self.init_tracker(sender_id)
        if tracker:
            if append_action_listen:
                tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
            self.save(tracker)
        return tracker

    def save(self, tracker):
        raise NotImplementedError()

    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        raise NotImplementedError()

    def stream_events(self, tracker: DialogueStateTracker) -> None:
        old_tracker = self.retrieve(tracker.sender_id)
        offset = len(old_tracker.events) if old_tracker else 0
        evts = tracker.events
        for evt in list(itertools.islice(evts, offset, len(evts))):
            body = {
                "sender_id": tracker.sender_id,
            }
            body.update(evt.as_dict())
            self.event_broker.publish(body)

    def keys(self):
        # type: () -> Optional[List[Text]]
        raise NotImplementedError()

    @staticmethod
    def serialise_tracker(tracker):
        dialogue = tracker.as_dialogue()
        return pickle.dumps(dialogue)

    def deserialise_tracker(self, sender_id, _json):
        dialogue = pickle.loads(_json)
        tracker = self.init_tracker(sender_id)
        tracker.recreate_from_dialogue(dialogue)
        return tracker


class InMemoryTrackerStore(TrackerStore):
    def __init__(self,
                 domain: Domain,
                 event_broker: Optional[EventChannel] = None
                 ) -> None:
        self.store = {}
        super(InMemoryTrackerStore, self).__init__(domain, event_broker)

    def save(self, tracker: DialogueStateTracker) -> None:
        if self.event_broker:
            self.stream_events(tracker)
        serialised = InMemoryTrackerStore.serialise_tracker(tracker)
        self.store[tracker.sender_id] = serialised

    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        if sender_id in self.store:
            logger.debug('Recreating tracker for '
                         'id \'{}\''.format(sender_id))
            return self.deserialise_tracker(sender_id, self.store[sender_id])
        else:
            logger.debug('Creating a new tracker for '
                         'id \'{}\'.'.format(sender_id))
            return None

    def keys(self) -> KeysView[Text]:
        return self.store.keys()


class RedisTrackerStore(TrackerStore):
    def keys(self):
        pass

    def __init__(self, domain, host='localhost',
                 port=6379, db=0, password=None, event_broker=None,
                 record_exp=None):

        import redis
        self.red = redis.StrictRedis(host=host, port=port, db=db,
                                     password=password)
        self.record_exp = record_exp
        super(RedisTrackerStore, self).__init__(domain, event_broker)

    def save(self, tracker, timeout=None):
        if self.event_broker:
            self.stream_events(tracker)

        if not timeout and self.record_exp:
            timeout = self.record_exp

        serialised_tracker = self.serialise_tracker(tracker)
        self.red.set(tracker.sender_id, serialised_tracker, ex=timeout)

    def retrieve(self, sender_id):
        stored = self.red.get(sender_id)
        if stored is not None:
            return self.deserialise_tracker(sender_id, stored)
        else:
            return None


class MongoTrackerStore(TrackerStore):
    def __init__(self,
                 domain,
                 host="mongodb://localhost:27017",
                 db="rasa",
                 username=None,
                 password=None,
                 auth_source="admin",
                 collection="conversations",
                 event_broker=None):
        from pymongo.database import Database
        from pymongo import MongoClient

        self.client = MongoClient(host,
                                  username=username,
                                  password=password,
                                  authSource=auth_source,
                                  # delay connect until process forking is done
                                  connect=False)

        self.db = Database(self.client, db)
        self.collection = collection
        super(MongoTrackerStore, self).__init__(domain, event_broker)

        self._ensure_indices()

    @property
    def conversations(self):
        return self.db[self.collection]

    def _ensure_indices(self):
        self.conversations.create_index("sender_id")

    def save(self, tracker, timeout=None):
        if self.event_broker:
            self.stream_events(tracker)

        state = tracker.current_state(EventVerbosity.ALL)

        self.conversations.update_one(
            {"sender_id": tracker.sender_id},
            {"$set": state},
            upsert=True)

    def retrieve(self, sender_id):
        stored = self.conversations.find_one({"sender_id": sender_id})

        # look for conversations which have used an `int` sender_id in the past
        # and update them.
        if stored is None and sender_id.isdigit():
            from pymongo import ReturnDocument
            stored = self.conversations.find_one_and_update(
                {"sender_id": int(sender_id)},
                {"$set": {"sender_id": str(sender_id)}},
                return_document=ReturnDocument.AFTER)

        if stored is not None:
            if self.domain:
                return DialogueStateTracker.from_dict(sender_id,
                                                      stored.get("events"),
                                                      self.domain.slots)
            else:
                logger.warning("Can't recreate tracker from mongo storage "
                               "because no domain is set. Returning `None` "
                               "instead.")
                return None
        else:
            return None

    def keys(self):
        return [c["sender_id"] for c in self.conversations.find()]


class SQLTrackerStore(TrackerStore):
    """Store which can save and retrieve trackers from an SQL database."""

    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class SQLEvent(Base):
        from sqlalchemy import Column, Integer, String, Float

        __tablename__ = 'events'

        id = Column(Integer, primary_key=True)
        sender_id = Column(String, nullable=False)
        type_name = Column(String, nullable=False)
        timestamp = Column(Float)
        intent_name = Column(String)
        action_name = Column(String)
        data = Column(String)

    def __init__(self,
                 domain: Optional[Domain] = None,
                 dialect: Text = 'sqlite',
                 url: Text = None,
                 db: Text = 'rasa.db',
                 username: Text = None,
                 password: Text = None,
                 event_broker: Optional[EventChannel] = None) -> None:
        from sqlalchemy.orm import sessionmaker
        from sqlalchemy.engine.url import URL
        from sqlalchemy import create_engine

        engine_url = URL(dialect, username, password, url, database=db)

        logger.debug('Attempting to connect to database '
                     'via "{}"'.format(engine_url.__to_string__()))

        self.engine = create_engine(engine_url)
        self.session = sessionmaker(bind=self.engine)()

        self.Base.metadata.create_all(self.engine)

        logger.debug("Connection to SQL database '{}' "
                     "successful".format(db))

        super(SQLTrackerStore, self).__init__(domain, event_broker)

    def keys(self) -> List[Text]:
        """Collect all keys of the items stored in the database."""
        # noinspection PyUnresolvedReferences
        return self.SQLEvent.__table__.columns.keys()

    def retrieve(self, sender_id: Text) -> DialogueStateTracker:
        """Create a tracker from all previously stored events."""

        query = self.session.query(self.SQLEvent)
        result = query.filter_by(sender_id=sender_id).all()
        events = [json.loads(event.data) for event in result]

        if self.domain and len(events) > 0:
            logger.debug("Recreating tracker "
                         "from sender id '{}'".format(sender_id))

            return DialogueStateTracker.from_dict(sender_id, events,
                                                  self.domain.slots)
        else:
            logger.debug("Can't retrieve tracker matching"
                         "sender id '{}' from SQL storage.  "
                         "Returning `None` instead.".format(sender_id))

    def save(self, tracker: DialogueStateTracker) -> None:
        """Update database with events from the current conversation."""

        if self.event_broker:
            self.stream_events(tracker)

        events = self._additional_events(tracker)  # only store recent events

        for event in events:

            data = event.as_dict()

            intent = data.get("parse_data", {}).get("intent", {}).get("name")
            action = data.get("name")
            timestamp = data.get("timestamp")

            # noinspection PyArgumentList
            self.session.add(self.SQLEvent(sender_id=tracker.sender_id,
                                           type_name=event.type_name,
                                           timestamp=timestamp,
                                           intent_name=intent,
                                           action_name=action,
                                           data=json.dumps(data)))
        self.session.commit()

        logger.debug("Tracker with sender_id '{}' "
                     "stored to database".format(tracker.sender_id))

    def _additional_events(self, tracker: DialogueStateTracker) -> Iterator:
        """Return events from the tracker which aren't currently stored."""

        from sqlalchemy import func
        query = self.session.query(func.max(self.SQLEvent.timestamp))
        max_timestamp = query.filter_by(sender_id=tracker.sender_id).scalar()

        if max_timestamp is None:
            max_timestamp = 0

        latest_events = []

        for event in reversed(tracker.events):
            if event.timestamp > max_timestamp:
                latest_events.append(event)
            else:
                break

        return reversed(latest_events)