""" only replace head nouns """ from nltk.corpus import wordnet as wn from os.path import join as pjoin import spacy import argparse import os import random # load spacy nlp = spacy.load('en') def valid(word, replacement): if replacement.lower() != replacement: return False synset_word = wn.synsets(word) synset_replacement = wn.synsets(replacement) for item_1 in synset_word: for item_2 in synset_replacement: if item_1 == item_2: return False # one-step hypernymy/hyponymy check for item_1 in synset_word: for subitem in item_1.hypernyms(): for item_2 in synset_replacement: if item_2 == subitem: return False for subitem in item_1.hyponyms(): for item_2 in synset_replacement: if item_2 == subitem: return False return True def English(word): for k in range(len(word)): if word[k] != ' ' and not ('a' <= word[k] <= 'z'): return False return True def tag_noun_chunk(text): text = text text_info = nlp(text) ws = list([x.text for x in text_info]) replace_tag = [False] * len(ws) for chunk in text_info.noun_chunks: if not English(chunk.text): continue replace_tag[chunk.end - 1] = True return ws, replace_tag prefix_list = ['train', 'dev', 'test'] if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--data_name', type=str, default='resnet152_nounex_precomp') parser.add_argument('--data_path', type=str, default='../data') args = parser.parse_args() # load all coco captions coco_captions = list() for prefix in prefix_list: os.system('mkdir {:s}/{:s}/{:s}_ex'.format(args.data_path, args.data_name, prefix)) path_coco_captions = pjoin('{:s}/{:s}/'.format(args.data_path, args.data_name), prefix + '_caps.txt') coco_captions += open(path_coco_captions).readlines() # count word-frequency for nouns and adjectives frequency = dict() for i in range(0, len(coco_captions), 10000): # default: treat the last word of a noun chunk as its head caption = ' '.join(coco_captions[i: min(i+10000, len(coco_captions))]) caption = nlp(caption.lower()) for word in caption: if word.pos_ == 'NOUN' and word.lemma_ != '--PRON--': word = word.text if word in frequency: frequency[word] += 1 else: frequency[word] = 1 print(i) # sort and select most frequent nouns frequency_pair = sorted([(token, frequency[token]) for token in frequency], key=lambda x: x[1], reverse=True) frequency_pair = list(filter(lambda x: x[1] >= 200, frequency_pair)) # frequency >= 200 # load concreteness concreteness = list(filter( lambda x: float(x[1]) > 0.6, [line.split() for line in open('{:s}/concreteness.txt'.format(args.data_path)).readlines()])) concrete_words = set([item[0] for item in concreteness]) # filter out concrete words frequency_pair = list(filter(lambda x: x[0] in concrete_words, frequency_pair)) frequent_words = [item[0] for item in frequency_pair] frequent_words_set = set(frequent_words) # get candidates for replacement candidate_words = dict() for w in frequent_words: candidate_words[w] = list(filter( lambda x: valid(w, x), [item for item in frequent_words_set])) # replace these words with words which are similar to themselves but have different meanings for prefix in prefix_list: path_coco_captions = pjoin('{:s}/{:s}/'.format(args.data_path, args.data_name), prefix + '_caps.txt') real_captions = open(path_coco_captions).readlines() invalid_word_set = set() for i in range(len(real_captions)): caption = real_captions[i] # load 5 recent captions for invalid word set if i % 5 == 0: invalid_word_set = set() for j in range(i, i + 5): for w in real_captions[j].split(): invalid_word_set.add(w) replacement_cnt = 0 fout = open('{:s}/{:s}/{:s}_ex/{:d}.txt'.format(args.data_path, args.data_name, prefix, i), 'w') caption = caption.lower() words, noun_chunk_tag = tag_noun_chunk(caption) replaced_words = [w for w in words] all_replacements = list() for j, token in enumerate(words): if not noun_chunk_tag[j]: continue if token in candidate_words: for candidate in candidate_words[token]: if candidate in invalid_word_set: # skip invalid words continue replaced_words[j] = candidate all_replacements.append(' '.join(replaced_words)) replacement_cnt += 1 replaced_words[j] = words[j] random.shuffle(all_replacements) for line in all_replacements: fout.write(line) if prefix in ['train', 'dev']: for j in range(replacement_cnt, 5): fout.write(' '.join(['<unk>' for _ in caption.strip().split()]) + '\n') fout.close() print(prefix, i, replacement_cnt)