""" only replace head nouns """
from functools import reduce
from os.path import join as pjoin
from pattern import *
import spacy
from word2number import w2n
import argparse
import os

# Load the English spaCy pipeline. The 'en' shortcut link only exists in
# spaCy 2.x; spaCy 3 removed shortcut links and raises OSError for it, so
# fall back to the full package name.
try:
    nlp = spacy.load('en')
except OSError:
    nlp = spacy.load('en_core_web_sm')


def _head_noun_after(text, start):
    """Locate the head noun of the phrase beginning at token ``start``.

    Skips a run of ADJ/CCONJ tokens starting at ``start``. If the next token
    is a NOUN, returns ``(noun_index, end)`` where ``noun_index`` is the last
    token of that NOUN run and ``end`` the position just past it. Otherwise
    returns ``(None, end)`` with ``end`` the first token that is neither ADJ
    nor CCONJ (possibly ``len(text)``), so the caller can re-examine it.
    """
    right = start
    while right < len(text) and text[right].pos_ in ('ADJ', 'CCONJ'):
        right += 1
    if right == len(text) or text[right].pos_ != 'NOUN':
        return None, right
    while right < len(text) and text[right].pos_ == 'NOUN':
        right += 1
    return right - 1, right


def detect_number(text):
    """Find (count word, described noun) pairs in a parsed caption.

    A candidate count is either a run of NUM tokens accepted by ``valid`` or
    the indefinite article 'a'/'an'. It may be followed by any number of
    ADJ/CCONJ tokens and must then be followed by a run of NOUN tokens; the
    last NOUN of that run is taken as the noun the count describes.

    Args:
        text: the parsed caption -- a spaCy ``Doc`` in this script; any
            indexable, sliceable sequence of tokens with ``.text`` and
            ``.pos_`` works.

    Returns:
        A list of ``[number, [start, end], noun, noun_index]`` entries:
        ``number`` is the count phrase, ``[start, end)`` its token span,
        ``noun`` the head noun's text and ``noun_index`` its position.
    """
    length = len(text)
    detect_results = list()
    left = 0
    while left < length:
        token = text[left]
        if token.pos_ == 'NUM':
            right = left + 1
            while right < length and text[right].pos_ == 'NUM':
                right += 1
            if right == length:
                # The number ends the caption, so no noun can follow it.
                # (Previously `continue` here left `left` unchanged and
                # looped forever.)
                break
            number = ' '.join([word.text for word in text[left:right]])
            if not valid(number):
                left = right
                continue
            num_range = [left, right]
        elif token.pos_ == 'DET' and token.text in ('a', 'an'):
            right = left + 1
            number = token.text
            num_range = [left, right]
        else:
            left += 1
            continue
        noun_index, right = _head_noun_after(text, right)
        if noun_index is not None:
            detect_results.append([number, num_range, text[noun_index].text, noun_index])
        # On failure `right` is the first non-ADJ/CCONJ token (or the end),
        # which is re-examined as a candidate on the next iteration.
        left = right
    return detect_results


def valid(word):
    """Return True when ``word`` is a single spelled-out number token.

    ``word`` is split on whitespace and every piece is parsed with
    ``w2n.word_to_num``. A ValueError from the parser, or more than one
    piece, makes it invalid. The text must also contain at least one
    lowercase ASCII letter, which rules out digit-only strings such as "2".
    """
    try:
        parsed = [w2n.word_to_num(piece) for piece in word.split()]
    except ValueError:
        return False
    if len(parsed) > 1:
        return False
    return any('a' <= ch <= 'z' for ch in word)


prefix_list = ['test', 'dev', 'train']

# Number words eligible for swapping. 'a' and 'an' are included because
# detect_number() also reports the indefinite article as a count. This is a
# tuple rather than a set: set iteration order for strings changes between
# runs (hash randomization), which made the order of generated replacements
# non-reproducible.
NUMBER_WORDS = ('two', 'six', 'eight', 'five', 'fifteen',
                'nine', 'a', 'an', 'four', 'three',
                'seven', 'fourteen', 'twenty', 'eighteen', 'ten',
                'one', 'thirty', 'eleven', 'twelve', 'thirteen',
                'sixteen', 'seventeen', 'nineteen', 'sixty', 'forty',
                'fifty', 'seventy', 'eighty')


def _build_number_dict(words):
    """Map each number word in ``words`` to its integer value.

    'a', 'an' and 'one' map to 1 directly; any other word is converted with
    ``w2n.word_to_num``. A word that raises ValueError is reported and left
    out. Insertion order follows ``words``.
    """
    number_dict = dict()
    for item in words:
        if item in ('a', 'an', 'one'):
            number_dict[item] = 1
            continue
        try:
            number_dict[item] = w2n.word_to_num(item)
        except ValueError:
            print('value error:', item)
    return number_dict


def _caption_replacements(text, number_dict):
    """Return every number-swapped variant of one parsed caption.

    Each count found by ``detect_number`` that appears in ``number_dict`` is
    replaced, in turn, by every number word with a different value. The head
    noun is singularized when the new value is 1 and pluralized when the old
    value was 1. Each variant is a space-joined token string ending with a
    newline.
    """
    tokens = [token.text for token in text]
    replacements = []
    for number, number_range, noun, noun_index in detect_number(text):
        if number not in number_dict:
            continue
        old_value = number_dict[number]
        for r, new_value in number_dict.items():
            if new_value == old_value:
                continue
            if new_value == 1:
                new_noun = singularize(noun)
            elif old_value == 1:
                new_noun = pluralize(noun)
            else:
                new_noun = noun
            word_list = list(tokens)
            word_list[noun_index] = new_noun
            new_word_list = word_list[:number_range[0]] + [r] + word_list[number_range[1]:]
            replacement = ' '.join(new_word_list)
            # readlines() keeps each caption's trailing '\n', which presumably
            # survives tokenization as a whitespace token, so most variants
            # already end in a newline. The file's last line may have none;
            # without this its variants would run together on one line.
            if not replacement.endswith('\n'):
                replacement += '\n'
            replacements.append(replacement)
    return replacements


# For each caption, detect_number() locates (count, head noun) pairs:
#   1. the count is a run of NUM tokens (or the article 'a'/'an'),
#   2. it is followed, possibly after ADJ/CCONJ tokens, by a run of NOUNs,
#   3. the described noun is the last token of that first NOUN run.
# Caption i of split <prefix> produces <data_path>/<data_name>/<prefix>_ex/<i>.txt
# with one counterfactual caption per line; train/dev files are padded with
# all-'<unk>' lines so they hold at least 5 lines.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_name', type=str, default='resnet152_numex_precomp')
    parser.add_argument('--data_path', type=str, default='../data/')
    args = parser.parse_args()
    data_dir = pjoin(args.data_path, args.data_name)
    # os.makedirs instead of `os.system('mkdir ...')`: no shell string built
    # from arguments, and no error when the directory already exists.
    for prefix in prefix_list:
        os.makedirs(pjoin(data_dir, prefix + '_ex'), exist_ok=True)
    number_dict = _build_number_dict(NUMBER_WORDS)
    for prefix in prefix_list:
        with open(pjoin(data_dir, prefix + '_caps.txt')) as fin:
            real_captions = fin.readlines()
        for i, caption in enumerate(real_captions):
            text = nlp(caption.lower())
            replacements = _caption_replacements(text, number_dict)
            out_path = pjoin(data_dir, prefix + '_ex', '{:d}.txt'.format(i))
            with open(out_path, 'w') as fout:
                fout.writelines(replacements)
                if prefix in ['train', 'dev']:
                    # Placeholder line: one '<unk>' per whitespace-separated caption word.
                    padding = ' '.join(['<unk>' for _ in caption.strip().split()]) + '\n'
                    for _ in range(len(replacements), 5):
                        fout.write(padding)
            print(prefix, i, len(replacements))