# -*- coding: utf-8 -*- # # Copyright 2015 Benjamin Kiessling # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. """ Utility functions for data loading and training of VGSL networks. """ import json import regex import torch import unicodedata import numpy as np import pkg_resources import bidi.algorithm as bd import shapely.geometry as geom import torch.nn.functional as F import torchvision.transforms.functional as tf from os import path from PIL import Image, ImageDraw from itertools import groupby from collections import Counter, defaultdict from torchvision import transforms from torch.utils.data import Dataset, DataLoader from torch.nn.utils.rnn import pad_sequence from typing import Dict, List, Tuple, Iterable, Sequence, Callable, Optional, Any, Union, cast from skimage.draw import polygon from kraken.lib.xml import parse_alto, parse_page, parse_xml from kraken.lib.util import is_bitonal from kraken.lib.codec import PytorchCodec from kraken.lib.models import TorchSeqRecognizer from kraken.lib.segmentation import extract_polygons, calculate_polygonal_environment from kraken.lib.exceptions import KrakenInputException from kraken.lib.lineest import CenterNormalizer, dewarp __all__ = ['BaselineSet', 'GroundTruthDataset', 'compute_error', 'generate_input_transforms', 'preparse_xml_data'] import logging logger = logging.getLogger(__name__) def generate_input_transforms(batch: int, height: int, width: int, channels: int, pad: int, valid_norm: bool = True, force_binarization=False) -> transforms.Compose: """ Generates a torchvision transformation converting a PIL.Image into a tensor usable in a network forward pass. Args: batch (int): mini-batch size height (int): height of input image in pixels width (int): width of input image in pixels channels (int): color channels of input pad (int): Amount of padding on horizontal ends of image valid_norm (bool): Enables/disables baseline normalization as a valid preprocessing step. If disabled we will fall back to standard scaling. force_binarization (bool): Forces binarization of input images using the nlbin algorithm. Returns: A torchvision transformation composition converting the input image to the appropriate tensor. """ scale = (height, width) # type: Tuple[int, int] center_norm = False mode = 'RGB' if channels == 3 else 'L' if height == 1 and width == 0 and channels > 3: perm = (1, 0, 2) scale = (channels, 0) if valid_norm: center_norm = True mode = 'L' elif height > 1 and width == 0 and channels in (1, 3): perm = (0, 1, 2) if valid_norm and channels == 1: center_norm = True elif height == 0 and width > 1 and channels in (1, 3): perm = (0, 1, 2) # fixed height and width image => bicubic scaling of the input image, disable padding elif height > 0 and width > 0 and channels in (1, 3): perm = (0, 1, 2) pad = 0 elif height == 0 and width == 0 and channels in (1, 3): perm = (0, 1, 2) pad = 0 else: raise KrakenInputException('Invalid input spec {}, {}, {}, {}, {}'.format(batch, height, width, channels, pad)) if mode != 'L' and force_binarization: raise KrakenInputException('Invalid input spec {}, {}, {}, {} in' ' combination with forced binarization.'.format(batch, height, width, channels, pad)) out_transforms = [] out_transforms.append(transforms.Lambda(lambda x: x.convert(mode))) if force_binarization: out_transforms.append(transforms.Lambda(lambda x: nlbin(im))) # dummy transforms to ensure we can determine color mode of input material # from first two transforms. It's stupid but it works. out_transforms.append(transforms.Lambda(lambda x: x)) if scale != (0, 0): if center_norm: lnorm = CenterNormalizer(scale[0]) out_transforms.append(transforms.Lambda(lambda x: dewarp(lnorm, x))) out_transforms.append(transforms.Lambda(lambda x: x.convert(mode))) else: out_transforms.append(transforms.Lambda(lambda x: _fixed_resize(x, scale, Image.LANCZOS))) if pad: out_transforms.append(transforms.Pad((pad, 0), fill=255)) out_transforms.append(transforms.ToTensor()) # invert out_transforms.append(transforms.Lambda(lambda x: x.max() - x)) out_transforms.append(transforms.Lambda(lambda x: x.permute(*perm))) return transforms.Compose(out_transforms) def _fixed_resize(img, size, interpolation=Image.LANCZOS): """ Doesn't do the annoying runtime scale dimension switching the default pytorch transform does. Args: img (PIL.Image): image to resize size (tuple): Tuple (height, width) """ w, h = img.size oh, ow = size if oh == 0: oh = int(h * ow/w) elif ow == 0: ow = int(w * oh/h) img = img.resize((ow, oh), interpolation) return img def _fast_levenshtein(seq1: Sequence[Any], seq2: Sequence[Any]) -> int: oneago = None thisrow = list(range(1, len(seq2) + 1)) + [0] rows = [thisrow] for x in range(len(seq1)): oneago, thisrow = thisrow, [0] * len(seq2) + [x + 1] for y in range(len(seq2)): delcost = oneago[y] + 1 addcost = thisrow[y - 1] + 1 subcost = oneago[y - 1] + (seq1[x] != seq2[y]) thisrow[y] = min(delcost, addcost, subcost) rows.append(thisrow) return thisrow[len(seq2) - 1] def global_align(seq1: Sequence[Any], seq2: Sequence[Any]) -> Tuple[int, List[str], List[str]]: """ Computes a global alignment of two strings. Args: seq1 (Sequence[Any]): seq2 (Sequence[Any]): Returns a tuple (distance, list(algn1), list(algn2)) """ # calculate cost and direction matrix cost = [[0] * (len(seq2) + 1) for x in range(len(seq1) + 1)] for i in range(1, len(cost)): cost[i][0] = i for i in range(1, len(cost[0])): cost[0][i] = i direction = [[(0, 0)] * (len(seq2) + 1) for x in range(len(seq1) + 1)] direction[0] = [(0, x) for x in range(-1, len(seq2))] for i in range(-1, len(direction) - 1): direction[i + 1][0] = (i, 0) for i in range(1, len(cost)): for j in range(1, len(cost[0])): delcost = ((i - 1, j), cost[i - 1][j] + 1) addcost = ((i, j - 1), cost[i][j - 1] + 1) subcost = ((i - 1, j - 1), cost[i - 1][j - 1] + (seq1[i - 1] != seq2[j - 1])) best = min(delcost, addcost, subcost, key=lambda x: x[1]) cost[i][j] = best[1] direction[i][j] = best[0] d = cost[-1][-1] # backtrace algn1: List[Any] = [] algn2: List[Any] = [] i = len(direction) - 1 j = len(direction[0]) - 1 while direction[i][j] != (-1, 0): k, l = direction[i][j] if k == i - 1 and l == j - 1: algn1.insert(0, seq1[i - 1]) algn2.insert(0, seq2[j - 1]) elif k < i: algn1.insert(0, seq1[i - 1]) algn2.insert(0, '') elif l < j: algn1.insert(0, '') algn2.insert(0, seq2[j - 1]) i, j = k, l return d, algn1, algn2 def compute_confusions(algn1: Sequence[str], algn2: Sequence[str]): """ Compute confusion matrices from two globally aligned strings. Args: align1 (Sequence[str]): sequence 1 align2 (Sequence[str]): sequence 2 Returns: A tuple (counts, scripts, ins, dels, subs) with `counts` being per-character confusions, `scripts` per-script counts, `ins` a dict with per script insertions, `del` an integer of the number of deletions, `subs` per script substitutions. """ counts: Dict[Tuple[str, str], int] = Counter() with pkg_resources.resource_stream(__name__, 'scripts.json') as fp: script_map = json.load(fp) def _get_script(c): for s, e, n in script_map: if ord(c) == s or (e and s <= ord(c) <= e): return n return 'Unknown' scripts: Dict[Tuple[str, str], int] = Counter() ins: Dict[Tuple[str, str], int] = Counter() dels: int = 0 subs: Dict[Tuple[str, str], int] = Counter() for u,v in zip(algn1, algn2): counts[(u, v)] += 1 for k, v in counts.items(): if k[0] == '': dels += v else: script = _get_script(k[0]) scripts[script] += v if k[1] == '': ins[script] += v elif k[0] != k[1]: subs[script] += v return counts, scripts, ins, dels, subs def compute_error(model: TorchSeqRecognizer, validation_set: Iterable[Dict[str, torch.Tensor]]) -> Tuple[int, int]: """ Computes error report from a model and a list of line image-text pairs. Args: model (kraken.lib.models.TorchSeqRecognizer): Model used for recognition validation_set (list): List of tuples (image, text) for validation Returns: A tuple with total number of characters and edit distance across the whole validation set. """ total_chars = 0 error = 0 for batch in validation_set: preds = model.predict_string(batch['image'], batch['seq_lens']) total_chars += batch['target_lens'].sum() for pred, text in zip(preds, batch['target']): error += _fast_levenshtein(pred, text) return total_chars, error def preparse_xml_data(filenames, format_type='xml', repolygonize=False): """ Loads training data from a set of xml files. Extracts line information from Page/ALTO xml files for training of recognition models. Args: filenames (list): List of XML files. format_type (str): Either `page`, `alto` or `xml` for autodetermination. repolygonize (bool): (Re-)calculates polygon information using the kraken algorithm. Returns: A list of dicts {'text': text, 'baseline': [[x0, y0], ...], 'boundary': [[x0, y0], ...], 'image': PIL.Image}. """ training_pairs = [] if format_type == 'xml': parse_fn = parse_xml elif format_type == 'alto': parse_fn = parse_alto elif format_type == 'page': parse_fn = parse_page else: raise Exception(f'invalid format {format_type} for preparse_xml_data') for fn in filenames: try: data = parse_fn(fn) except KrakenInputException as e: logger.warning(e) continue try: Image.open(data['image']) except FileNotFoundError as e: logger.warning(f'Could not open file {e.filename} in {fn}') continue if repolygonize: logger.info('repolygonizing {} lines in {}'.format(len(data['lines']), data['image'])) data['lines'] = _repolygonize(data['image'], data['lines']) for line in data['lines']: training_pairs.append({'image': data['image'], **line}) return training_pairs def _repolygonize(im: Image.Image, lines): """ Helper function taking an output of the lib.xml parse_* functions and recalculating the contained polygonization. Args: im (Image.Image): Input image lines (list): List of dicts [{'boundary': [[x0, y0], ...], 'baseline': [[x0, y0], ...], 'text': 'abcvsd'}, {...] Returns: A data structure `lines` with a changed polygonization. """ im = Image.open(im).convert('L') polygons = calculate_polygonal_environment(im, [x['baseline'] for x in lines]) return [{'boundary': polygon, 'baseline': orig['baseline'], 'text': orig['text']} for orig, polygon in zip(lines, polygons)] def collate_sequences(batch): """ Sorts and pads sequences. """ sorted_batch = sorted(batch, key=lambda x: x['image'].shape[2], reverse=True) seqs = [x['image'] for x in sorted_batch] seq_lens = torch.LongTensor([seq.shape[2] for seq in seqs]) max_len = seqs[0].shape[2] seqs = torch.stack([F.pad(seq, pad=(0, max_len-seq.shape[2])) for seq in seqs]) if isinstance(sorted_batch[0]['target'], str): labels = [x['target'] for x in sorted_batch] else: labels = torch.cat([x['target'] for x in sorted_batch]).long() label_lens = torch.LongTensor([len(x['target']) for x in sorted_batch]) return {'image': seqs, 'target': labels, 'seq_lens': seq_lens, 'target_lens': label_lens} class InfiniteDataLoader(DataLoader): """ Version of DataLoader that auto-reinitializes the iterator once it is exhausted. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.dataset_iter = super().__iter__() def __iter__(self): return self def __next__(self): try: sample = next(self.dataset_iter) except StopIteration: self.dataset_iter = super().__iter__() sample = next(self.dataset_iter) return sample class PolygonGTDataset(Dataset): """ Dataset for training a line recognition model from polygonal/baseline data. """ def __init__(self, normalization: Optional[str] = None, whitespace_normalization: bool = True, reorder: bool = True, im_transforms: Callable[[Any], torch.Tensor] = transforms.Compose([]), preload: bool = True, augmentation: bool = False) -> None: self._images = [] # type: Union[List[Image], List[torch.Tensor]] self._gt = [] # type: List[str] self.alphabet = Counter() # type: Counter self.text_transforms = [] # type: List[Callable[[str], str]] # split image transforms into two. one part giving the final PIL image # before conversion to a tensor and the actual tensor conversion part. self.head_transforms = transforms.Compose(im_transforms.transforms[:2]) self.tail_transforms = transforms.Compose(im_transforms.transforms[2:]) self.transforms = im_transforms self.preload = preload self.aug = None self.seg_type = 'baselines' # built text transformations if normalization: self.text_transforms.append(lambda x: unicodedata.normalize(cast(str, normalization), x)) if whitespace_normalization: self.text_transforms.append(lambda x: regex.sub('\s', ' ', x).strip()) if reorder: self.text_transforms.append(bd.get_display) if augmentation: from albumentations import ( Compose, ToFloat, FromFloat, Flip, OneOf, MotionBlur, MedianBlur, Blur, ShiftScaleRotate, OpticalDistortion, ElasticTransform, RandomBrightnessContrast, ) self.aug = Compose([ ToFloat(), OneOf([ MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1), ], p=0.2), ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=3, p=0.2), OneOf([ OpticalDistortion(p=0.3), ElasticTransform(p=0.1), ], p=0.2), ], p=0.5) self.im_mode = '1' def add(self, image: Union[str, Image.Image], text: str, baseline: List[Tuple[int, int]], boundary: List[Tuple[int, int]], *args, **kwargs): """ Adds a line to the dataset. Args: im (path): Path to the whole page image text (str): Transcription of the line. baseline (list): A list of coordinates [[x0, y0], ..., [xn, yn]]. boundary (list): A polygon mask for the line. """ for func in self.text_transforms: text = func(text) if not text: raise KrakenInputException('Text line is empty after transformations') if self.preload: if not isinstance(image, Image.Image): im = Image.open(image) im, _ = next(extract_polygons(im, {'type': 'baselines', 'lines': [{'baseline': baseline, 'boundary': boundary}]})) try: im = self.head_transforms(im) if not is_bitonal(im): self.im_mode = im.mode im = self.tail_transforms(im) except ValueError: raise KrakenInputException('Image transforms failed on {}'.format(image)) self._images.append(im) else: self._images.append((image, baseline, boundary)) self._gt.append(text) self.alphabet.update(text) def encode(self, codec: Optional[PytorchCodec] = None) -> None: """ Adds a codec to the dataset and encodes all text lines. Has to be run before sampling from the dataset. """ if codec: self.codec = codec else: self.codec = PytorchCodec(''.join(self.alphabet.keys())) self.training_set = [] # type: List[Tuple[Union[Image, torch.Tensor], torch.Tensor]] for im, gt in zip(self._images, self._gt): self.training_set.append((im, self.codec.encode(gt))) def no_encode(self) -> None: """ Creates an unencoded dataset. """ self.training_set = [] # type: List[Tuple[Union[Image, torch.Tensor], str]] for im, gt in zip(self._images, self._gt): self.training_set.append((im, gt)) def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: if self.preload: x, y = self.training_set[index] if self.aug: x = x.permute((1, 2, 0)).numpy() o = self.aug(image=x) x = torch.tensor(o['image'].transpose(2, 0, 1)) return {'image': x, 'target': y} else: item = self.training_set[index] try: logger.debug('Attempting to load {}'.format(item[0])) im = item[0][0] if not isinstance(im, Image.Image): im = Image.open(im) im, _ = next(extract_polygons(im, {'type': 'baselines', 'lines': [{'baseline': item[0][1], 'boundary': item[0][2]}]})) im = self.head_transforms(im) if not is_bitonal(im): self.im_mode = im.mode im = self.tail_transforms(im) if self.aug: im = im.permute((1, 2, 0)).numpy() o = self.aug(image=im) im = torch.tensor(o['image'].transpose(2, 0, 1)) return {'image': im, 'target': item[1]} except Exception: idx = np.random.randint(0, len(self.training_set)) logger.debug('Failed. Replacing with sample {}'.format(idx)) return self[np.random.randint(0, len(self.training_set))] def __len__(self) -> int: return len(self.training_set) class GroundTruthDataset(Dataset): """ Dataset for training a line recognition model. All data is cached in memory. """ def __init__(self, split: Callable[[str], str] = lambda x: path.splitext(x)[0], suffix: str = '.gt.txt', normalization: Optional[str] = None, whitespace_normalization: bool = True, reorder: bool = True, im_transforms: Callable[[Any], torch.Tensor] = transforms.Compose([]), preload: bool = True, augmentation: bool = False) -> None: """ Reads a list of image-text pairs and creates a ground truth set. Args: split (func): Function for generating the base name without extensions from paths suffix (str): Suffix to attach to image base name for text retrieval mode (str): Image color space. Either RGB (color) or L (grayscale/bw). Only L is compatible with vertical scaling/dewarping. scale (int, tuple): Target height or (width, height) of dewarped line images. Vertical-only scaling is through CenterLineNormalizer, resizing with Lanczos interpolation. Set to 0 to disable. normalization (str): Unicode normalization for gt whitespace_normalization (str): Normalizes unicode whitespace and strips whitespace. reorder (bool): Whether to rearrange code points in "display"/LTR order im_transforms (func): Function taking an PIL.Image and returning a tensor suitable for forward passes. preload (bool): Enables preloading and preprocessing of image files. """ self.suffix = suffix self.split = lambda x: split(x) + self.suffix self._images = [] # type: Union[List[Image], List[torch.Tensor]] self._gt = [] # type: List[str] self.alphabet = Counter() # type: Counter self.text_transforms = [] # type: List[Callable[[str], str]] # split image transforms into two. one part giving the final PIL image # before conversion to a tensor and the actual tensor conversion part. self.head_transforms = transforms.Compose(im_transforms.transforms[:2]) self.tail_transforms = transforms.Compose(im_transforms.transforms[2:]) self.aug = None self.preload = preload self.seg_type = 'bbox' # built text transformations if normalization: self.text_transforms.append(lambda x: unicodedata.normalize(cast(str, normalization), x)) if whitespace_normalization: self.text_transforms.append(lambda x: regex.sub('\s', ' ', x).strip()) if reorder: self.text_transforms.append(bd.get_display) if augmentation: from albumentations import ( Compose, ToFloat, FromFloat, Flip, OneOf, MotionBlur, MedianBlur, Blur, ShiftScaleRotate, OpticalDistortion, ElasticTransform, RandomBrightnessContrast, ) self.aug = Compose([ ToFloat(), OneOf([ MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1), ], p=0.2), ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2), OneOf([ OpticalDistortion(p=0.3), ElasticTransform(p=0.1), ], p=0.2), ], p=0.5) self.im_mode = '1' def add(self, image: Union[str, Image.Image], *args, **kwargs) -> None: """ Adds a line-image-text pair to the dataset. Args: image (str): Input image path """ with open(self.split(image), 'r', encoding='utf-8') as fp: gt = fp.read().strip('\n\r') for func in self.text_transforms: gt = func(gt) if not gt: raise KrakenInputException('Text line is empty ({})'.format(fp.name)) if self.preload: try: im = Image.open(image) im = self.head_transforms(im) if not is_bitonal(im): self.im_mode = im.mode im = self.tail_transforms(im) except ValueError: raise KrakenInputException('Image transforms failed on {}'.format(image)) self._images.append(im) else: self._images.append(image) self._gt.append(gt) self.alphabet.update(gt) def add_loaded(self, image: Image.Image, gt: str) -> None: """ Adds an already loaded line-image-text pair to the dataset. Args: image (PIL.Image.Image): Line image gt (str): Text contained in the line image """ if self.preload: try: im = self.head_transforms(im) if not is_bitonal(im): self.im_mode = im.mode im = self.tail_transforms(im) except ValueError: raise KrakenInputException('Image transforms failed on {}'.format(image)) self._images.append(im) else: self._images.append(image) for func in self.text_transforms: gt = func(gt) self._gt.append(gt) self.alphabet.update(gt) def encode(self, codec: Optional[PytorchCodec] = None) -> None: """ Adds a codec to the dataset and encodes all text lines. Has to be run before sampling from the dataset. """ if codec: self.codec = codec else: self.codec = PytorchCodec(''.join(self.alphabet.keys())) self.training_set = [] # type: List[Tuple[Union[Image, torch.Tensor], torch.Tensor]] for im, gt in zip(self._images, self._gt): self.training_set.append((im, self.codec.encode(gt))) def no_encode(self) -> None: """ Creates an unencoded dataset. """ self.training_set = [] # type: List[Tuple[Union[Image, torch.Tensor], str]] for im, gt in zip(self._images, self._gt): self.training_set.append((im, gt)) def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: if self.preload: x, y = self.training_set[index] if self.aug: im = x.permute((1, 2, 0)).numpy() o = self.aug(image=im) im = torch.tensor(o['image'].transpose(2, 0, 1)) return {'image': im, 'target': y} return {'image': x, 'target': y} else: item = self.training_set[index] try: logger.debug('Attempting to load {}'.format(item[0])) im = item[0] if not isinstance(im, Image.Image): im = Image.open(im) im = self.head_transforms(im) if not is_bitonal(im): self.im_mode = im.mode im = self.tail_transforms(im) if self.aug: im = im.permute((1, 2, 0)).numpy() o = self.aug(image=im) im = torch.tensor(o['image'].transpose(2, 0, 1)) return {'image': im, 'target': item[1]} except Exception: idx = np.random.randint(0, len(self.training_set)) logger.debug('Failed. Replacing with sample {}'.format(idx)) return self[np.random.randint(0, len(self.training_set))] def __len__(self) -> int: return len(self.training_set) class BaselineSet(Dataset): """ Dataset for training a baseline/region segmentation model. """ def __init__(self, imgs: Sequence[str] = None, suffix: str = '.path', line_width: int = 4, im_transforms: Callable[[Any], torch.Tensor] = transforms.Compose([]), mode: str = 'path', augmentation: bool = False, valid_baselines: Sequence[str] = None, merge_baselines: Dict[str, Sequence[str]] = None, valid_regions: Sequence[str] = None, merge_regions: Dict[str, Sequence[str]] = None): """ Reads a list of image-json pairs and creates a data set. Args: imgs (list): suffix (int): Suffix to attach to image base name to load JSON files from. line_width (int): Height of the baseline in the scaled input. target_size (tuple): Target size of the image as a (height, width) tuple. mode (str): Either path, alto, page, xml, or None. In alto, page, and xml mode the baseline paths and image data is retrieved from an ALTO/PageXML file. In `None` mode data is iteratively added through the `add` method. augmentation (bool): Enable/disable augmentation. valid_baselines (list): Sequence of valid baseline identifiers. If `None` all are valid. merge_baselines (dict): Sequence of baseline identifiers to merge. Note that merging occurs after entities not in valid_* have been discarded. valid_regions (list): Sequence of valid region identifiers. If `None` all are valid. merge_regions (dict): Sequence of region identifiers to merge. Note that merging occurs after entities not in valid_* have been discarded. """ super().__init__() self.mode = mode self.im_mode = '1' self.aug = None self.targets = [] # n-th entry contains semantic of n-th class self.class_mapping = {'aux': {'_start_separator': 0, '_end_separator': 1}, 'baselines': {}, 'regions': {}} self.num_classes = 2 self.mbl_dict = merge_baselines if merge_baselines is not None else {} self.mreg_dict = merge_regions if merge_regions is not None else {} self.valid_baselines = valid_baselines self.valid_regions = valid_regions if mode in ['alto', 'page', 'xml']: if mode == 'alto': fn = parse_alto elif mode == 'page': fn = parse_page elif mode == 'xml': fn = parse_xml im_paths = [] self.targets = [] for img in imgs: try: data = fn(img) im_paths.append(data['image']) lines = defaultdict(list) for line in data['lines']: if valid_baselines is None or line['script'] in valid_baselines: lines[self.mbl_dict.get(line['script'], line['script'])].append(line['baseline']) regions = defaultdict(list) for k, v in data['regions'].items(): if valid_regions is None or k in valid_regions: regions[self.mreg_dict.get(k, k)].extend(v) data['regions'] = regions self.targets.append({'baselines': lines, 'regions': data['regions']}) except KrakenInputException as e: logger.warning(e) continue # get line types imgs = im_paths # calculate class mapping line_types = set() region_types = set() for page in self.targets: for line_type in page['baselines'].keys(): line_types.add(line_type) for reg_type in page['regions'].keys(): region_types.add(reg_type) idx = -1 for idx, line_type in enumerate(line_types): self.class_mapping['baselines'][line_type] = idx + self.num_classes self.num_classes += idx + 1 idx = -1 for idx, reg_type in enumerate(region_types): self.class_mapping['regions'][reg_type] = idx + self.num_classes self.num_classes += idx + 1 elif mode == 'path': pass elif mode is None: imgs = [] else: raise Exception('invalid dataset mode') if augmentation: from albumentations import ( Compose, ToFloat, FromFloat, RandomRotate90, Flip, OneOf, MotionBlur, MedianBlur, Blur, ShiftScaleRotate, OpticalDistortion, ElasticTransform, RandomBrightnessContrast, HueSaturationValue, ) self.aug = Compose([ ToFloat(), RandomRotate90(), Flip(), OneOf([ MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1), ], p=0.2), ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2), OneOf([ OpticalDistortion(p=0.3), ElasticTransform(p=0.1), ], p=0.2), HueSaturationValue(hue_shift_limit=20, sat_shift_limit=0.1, val_shift_limit=0.1, p=0.3), ], p=0.5) self.imgs = imgs self.line_width = line_width # split image transforms into two. one part giving the final PIL image # before conversion to a tensor and the actual tensor conversion part. self.head_transforms = transforms.Compose(im_transforms.transforms[:2]) self.tail_transforms = transforms.Compose(im_transforms.transforms[2:]) self.seg_type = None def add(self, image: Union[str, Image.Image], baselines: List[List[List[Tuple[int, int]]]] = None, regions: Dict[str, List[List[Tuple[int, int]]]] = None, *args, **kwargs): """ Adds a page to the dataset. Args: im (path): Path to the whole page image baseline (dict): A list containing dicts with a list of coordinates and script types [{'baseline': [[x0, y0], ..., [xn, yn]], 'script': 'script_type'}, ...] regions (dict): A dict containing list of lists of coordinates {'region_type_0': [[x0, y0], ..., [xn, yn]]], 'region_type_1': ...}. """ if self.mode: raise Exception('The `add` method is incompatible with dataset mode {}'.format(self.mode)) baselines_ = defaultdict(list) for line in baselines: line_type = self.mbl_dict.get(line['script'], line['script']) if self.valid_baselines is None or line['script'] in self.valid_baselines: baselines_[line_type].append(line['baseline']) if line_type not in self.class_mapping['baselines']: self.num_classes += 1 self.class_mapping['baselines'][line_type] = self.num_classes - 1 regions_ = defaultdict(list) for k, v in regions.items(): reg_type = self.mreg_dict.get(k, k) if self.valid_regions is None or reg_type in self.valid_regions: regions_[reg_type].extend(v) if reg_type not in self.class_mapping['regions']: self.num_classes += 1 self.class_mapping['regions'][reg_type] = self.num_classes - 1 self.targets.append({'baselines': baselines_, 'regions': regions_}) self.imgs.append(image) def __getitem__(self, idx): im = self.imgs[idx] if self.mode != 'path': target = self.targets[idx] else: with open('{}.path'.format(path.splitext(im)[0]), 'r') as fp: target = json.load(fp) if not isinstance(im, Image.Image): try: logger.debug('Attempting to load {}'.format(im)) im = Image.open(im) im, target = self.transform(im, target) return {'image': im, 'target': target} except Exception: idx = np.random.randint(0, len(self.imgs)) logger.debug('Failed. Replacing with sample {}'.format(idx)) return self[np.random.randint(0, len(self.imgs))] im, target = self.transform(im, target) return {'image': im, 'target': target} @staticmethod def _get_ortho_line(lineseg, point, line_width, offset): lineseg = np.array(lineseg) norm_vec = lineseg[1,...] - lineseg[0,...] norm_vec_len = np.sqrt(np.sum(norm_vec**2)) unit_vec = norm_vec / norm_vec_len ortho_vec = unit_vec[::-1] * ((1,-1), (-1,1)) if offset == 'l': point -= unit_vec * line_width else: point += unit_vec * line_width return (ortho_vec * 10 + point).astype('int').tolist() def transform(self, image, target): orig_size = image.size image = self.head_transforms(image) if not is_bitonal(image): self.im_mode = image.mode image = self.tail_transforms(image) scale = image.shape[2]/orig_size[0] t = torch.zeros((self.num_classes,) + image.shape[1:]) start_sep_cls = self.class_mapping['aux']['_start_separator'] end_sep_cls = self.class_mapping['aux']['_end_separator'] for key, lines in target['baselines'].items(): cls_idx = self.class_mapping['baselines'][key] for line in lines: # buffer out line to desired width line = [k for k, g in groupby(line)] line = np.array(line)*scale line_pol = np.array(geom.LineString(line).buffer(self.line_width/2, cap_style=2).boundary, dtype=np.int) rr, cc = polygon(line_pol[:,1], line_pol[:,0], shape=image.shape[1:]) t[cls_idx, rr, cc] = 1 start_sep = np.array(geom.LineString(self._get_ortho_line(line[:2], line[0], self.line_width/2, 'l')).buffer(self.line_width/2, cap_style=2).boundary, dtype=np.int) rr, cc = polygon(start_sep[:,1], start_sep[:,0], shape=image.shape[1:]) t[start_sep_cls, rr, cc] = 1 end_sep = np.array(geom.LineString(self._get_ortho_line(line[-2:], line[-1], self.line_width/2, 'r')).buffer(self.line_width/2, cap_style=2).boundary, dtype=np.int) rr, cc = polygon(end_sep[:,1], end_sep[:,0], shape=image.shape[1:]) t[end_sep_cls, rr, cc] = 1 for key, regions in target['regions'].items(): cls_idx = self.class_mapping['regions'][key] for region in regions: region = np.array(region)*scale rr, cc = polygon(region[:,1], region[:,0], shape=image.shape[1:]) t[cls_idx, rr, cc] = 1 target = t if self.aug: image = image.permute(1, 2, 0).numpy() target = target.permute(1, 2, 0).numpy() o = self.aug(image=image, mask=target) image = torch.tensor(o['image']).permute(2, 0, 1) target = torch.tensor(o['mask']).permute(2, 0, 1) return image, target def __len__(self): return len(self.imgs)