import json
import re
import string

import nltk
import numpy as np
import pandas as pd
from nltk.corpus import stopwords
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.tokenize import TweetTokenizer
from sklearn.externals import joblib
from steppy.base import BaseTransformer

lem = WordNetLemmatizer()
tokenizer = TweetTokenizer()

nltk.download('wordnet')
nltk.download('stopwords')

eng_stopwords = set(stopwords.words("english"))
with open('steps/resources/apostrophes.json', 'r') as f:
    APOSTROPHES_WORDS = json.load(f)


class WordListFilter(BaseTransformer):
    def __init__(self, word_list_filepath):
        super().__init__()
        self.word_set = self._read_data(word_list_filepath)

    def transform(self, X):
        X = self._transform(X)
        return {'X': X}

    def _transform(self, X):
        X = pd.DataFrame(X, columns=['text']).astype(str)
        X['text'] = X['text'].apply(self._filter_words)
        return X['text'].values

    def _filter_words(self, x):
        x = x.lower()
        x = ' '.join([w for w in x.split() if w in self.word_set])
        return x

    def _read_data(self, filepath):
        with open(filepath, 'r+') as f:
            data = f.read()
        return set(data.split('\n'))

    def load(self, filepath):
        return self

    def persist(self, filepath):
        joblib.dump({}, filepath)


class TextCleaner(BaseTransformer):
    def __init__(self,
                 drop_punctuation,
                 drop_newline,
                 drop_multispaces,
                 all_lower_case,
                 fill_na_with,
                 deduplication_threshold,
                 apostrophes,
                 use_stopwords):
        super().__init__()
        self.drop_punctuation = drop_punctuation
        self.drop_newline = drop_newline
        self.drop_multispaces = drop_multispaces
        self.all_lower_case = all_lower_case
        self.fill_na_with = fill_na_with
        self.deduplication_threshold = deduplication_threshold
        self.apostrophes = apostrophes
        self.use_stopwords = use_stopwords

    def transform(self, X):
        X = pd.DataFrame(X, columns=['text']).astype(str)
        X['text'] = X['text'].apply(self._transform)
        if self.fill_na_with:
            X['text'] = X['text'].fillna(self.fill_na_with).values
        return {'X': X['text'].values}

    def _transform(self, x):
        if self.all_lower_case:
            x = self._lower(x)
        if self.drop_punctuation:
            x = self._remove_punctuation(x)
        if self.drop_newline:
            x = self._remove_newline(x)
        if self.drop_multispaces:
            x = self._substitute_multiple_spaces(x)
        if self.deduplication_threshold is not None:
            x = self._deduplicate(x)
        if self.apostrophes:
            x = self._apostrophes(x)
        if self.use_stopwords:
            x = self._use_stopwords(x)
        return x

    def _use_stopwords(self, x):
        words = tokenizer.tokenize(x)
        words = [w for w in words if not w in eng_stopwords]
        x = " ".join(words)
        return x

    def _apostrophes(self, x):
        words = tokenizer.tokenize(x)
        words = [APOSTROPHES_WORDS[word] if word in APOSTROPHES_WORDS else word for word in words]
        words = [lem.lemmatize(word, "v") for word in words]
        words = [w for w in words if not w in eng_stopwords]
        x = " ".join(words)
        return x

    def _lower(self, x):
        return x.lower()

    def _remove_punctuation(self, x):
        return re.sub(r'[^\w\s]', ' ', x)

    def _remove_newline(self, x):
        x = x.replace('\n', ' ')
        x = x.replace('\n\n', ' ')
        return x

    def _substitute_multiple_spaces(self, x):
        return ' '.join(x.split())

    def _deduplicate(self, x):
        word_list = x.split()
        num_words = len(word_list)
        if num_words == 0:
            return x
        else:
            num_unique_words = len(set(word_list))
            unique_ratio = num_words / num_unique_words
            if unique_ratio > self.deduplication_threshold:
                x = ' '.join(x.split()[:num_unique_words])
            return x

    def load(self, filepath):
        params = joblib.load(filepath)
        self.drop_punctuation = params['drop_punctuation']
        self.all_lower_case = params['all_lower_case']
        self.fill_na_with = params['fill_na_with']
        return self

    def persist(self, filepath):
        params = {'drop_punctuation': self.drop_punctuation,
                  'all_lower_case': self.all_lower_case,
                  'fill_na_with': self.fill_na_with,
                  }
        joblib.dump(params, filepath)


class TextCounter(BaseTransformer):
    def transform(self, X):
        X = pd.DataFrame(X, columns=['text']).astype(str)
        X = X['text'].apply(self._transform)
        X['caps_vs_length'] = self._caps_vs_length(X)
        X['num_symbols'] = X['text'].apply(lambda comment: sum(comment.count(w) for w in '*&$%'))
        X['num_words'] = X['text'].apply(lambda comment: len(comment.split()))
        X['num_unique_words'] = X['text'].apply(lambda comment: len(set(w for w in comment.split())))
        X['words_vs_unique'] = self._words_vs_unique(X)
        X['mean_word_len'] = X['text'].apply(lambda x: np.mean([len(w) for w in str(x).split()]))
        X.drop('text', axis=1, inplace=True)
        X.fillna(0.0, inplace=True)
        return {'X': X}

    def _transform(self, x):
        features = {}
        features['text'] = x
        features['char_count'] = char_count(x)
        features['word_count'] = word_count(x)
        features['punctuation_count'] = punctuation_count(x)
        features['upper_case_count'] = upper_case_count(x)
        features['lower_case_count'] = lower_case_count(x)
        features['digit_count'] = digit_count(x)
        features['space_count'] = space_count(x)
        features['newline_count'] = newline_count(x)
        return pd.Series(features)

    def _caps_vs_length(self, X):
        try:
            return X.apply(lambda row: float(row['upper_case_count']) / float(row['char_count']), axis=1)
        except ZeroDivisionError:
            return 0

    def _words_vs_unique(self, X):
        try:
            return X['num_unique_words'] / X['num_words']
        except ZeroDivisionError:
            return 0

    def load(self, filepath):
        return self

    def persist(self, filepath):
        joblib.dump({}, filepath)


def char_count(x):
    return len(x)


def word_count(x):
    return len(x.split())


def newline_count(x):
    return x.count('\n')


def upper_case_count(x):
    return sum(c.isupper() for c in x)


def lower_case_count(x):
    return sum(c.islower() for c in x)


def digit_count(x):
    return sum(c.isdigit() for c in x)


def space_count(x):
    return sum(c.isspace() for c in x)


def punctuation_count(x):
    return occurrence(x, string.punctuation)


def occurrence(s1, s2):
    return sum([1 for x in s1 if x in s2])