from typing import Iterable, Dict, Tuple, List

import razdel
import nltk
from allennlp.data.instance import Instance
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.tokenizers import Tokenizer, Token
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.fields import TextField, ListField, SequenceLabelField


class SummarizationSentencesTaggerReader(DatasetReader):
    """Dataset reader for extractive summarization framed as sentence tagging.

    Each produced ``Instance`` holds the text's sentences as a ``ListField``
    of indexed token fields and, when labels are available, one tag per
    sentence (``SequenceLabelField``) marking summary membership.

    Subclasses must implement ``parse_set`` to yield
    ``(text, summary, sentences, tags)`` tuples from a data file; ``summary``
    is carried through the parse but not used when building instances.
    """

    def __init__(self,
                 tokenizer: Tokenizer,
                 language: str,
                 source_token_indexers: Dict[str, TokenIndexer] = None,
                 max_sentences_count: int = 100,
                 sentence_max_tokens: int = 100,
                 lowercase: bool = True,
                 lazy: bool = True) -> None:
        """
        :param tokenizer: word-level tokenizer applied to each sentence.
        :param language: ``"ru"`` selects razdel for sentence splitting;
            any other value falls back to NLTK's ``sent_tokenize``.
        :param source_token_indexers: indexers for source tokens; defaults
            to a single-id indexer under the ``"tokens"`` key.
        :param max_sentences_count: sentences beyond this count are dropped.
        :param sentence_max_tokens: per-sentence token cap, applied before
            the START/END symbols are added.
        :param lowercase: whether to lowercase each sentence before tokenizing.
        :param lazy: passed through to ``DatasetReader`` (stream vs. cache).
        """
        super().__init__(lazy=lazy)
        self._tokenizer = tokenizer
        self._lowercase = lowercase
        self._language = language
        self._max_sentences_count = max_sentences_count
        self._sentence_max_tokens = sentence_max_tokens
        self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}

    def _read(self, file_path: str) -> Iterable[Instance]:
        """Yield one ``Instance`` per record parsed from ``file_path``."""
        for text, summary, sentences, tags in self.parse_set(file_path):
            # Validate explicitly instead of with `assert` (stripped under -O).
            # The original expression also raised TypeError via len(None) when
            # exactly one of (sentences, tags) was None; report that case as a
            # proper validation error instead.
            valid = (sentences is None and tags is None) or (
                sentences is not None and tags is not None
                and len(sentences) == len(tags))
            if not valid:
                raise ValueError(
                    "sentences and tags must both be None or have equal length")
            yield self.text_to_instance(text, sentences, tags)

    def text_to_instance(self,
                         text: str,
                         sentences: List[str] = None,
                         tags: List[int] = None) -> Instance:
        """Build an ``Instance`` from raw text.

        When ``sentences`` is None the text is sentence-split here (razdel
        for Russian, NLTK otherwise). Each sentence is optionally lowercased,
        tokenized, truncated to ``sentence_max_tokens``, and wrapped with
        START/END symbols. Tags, when given, are truncated to match the
        sentence cap.
        """
        if sentences is None:
            if self._language == "ru":
                sentences = [s.text for s in razdel.sentenize(text)]
            else:
                sentences = nltk.tokenize.sent_tokenize(text)

        sentences_tokens = []
        for sentence in sentences[:self._max_sentences_count]:
            if self._lowercase:
                sentence = sentence.lower()
            tokens = self._tokenizer.tokenize(sentence)[:self._sentence_max_tokens]
            tokens.insert(0, Token(START_SYMBOL))
            tokens.append(Token(END_SYMBOL))
            sentences_tokens.append(TextField(tokens, self._source_token_indexers))

        sentences_tokens_indexed = ListField(sentences_tokens)
        result = {'source_sentences': sentences_tokens_indexed}
        if tags:
            result["sentences_tags"] = SequenceLabelField(
                tags[:self._max_sentences_count], sentences_tokens_indexed)
        return Instance(result)

    def parse_set(self, path: str) -> Iterable[Tuple[str, str, List[str], List[int]]]:
        """Parse a dataset file.

        Subclasses yield ``(text, summary, sentences, tags)`` tuples, as
        consumed by ``_read``. (The original annotation claimed a 2-tuple,
        contradicting ``_read``'s 4-way unpacking — fixed here.)
        """
        raise NotImplementedError()