import chainer
import chainer.functions as F
import chainer.links as L
from chainer import Chain
from chainer.backends import cuda

from resnet.resnet_gn import ResNet
from train_utils.autocopy import maybe_copy


class TextRecognizer(Chain):
    """Predict a fixed-length sequence of character scores for every ROI.

    Each region of interest is encoded by a ResNet feature extractor, globally
    average-pooled and scored by a single linear classifier. An LSTM decoder is
    stubbed out (commented) below; without it every character slot of a box
    receives the same scores.
    """

    def __init__(self, num_chars, num_classes, **kwargs):
        """Build the recognizer.

        Args:
            num_chars: Number of character slots predicted per box.
            num_classes: Number of output classes of the linear classifier.
            **kwargs: Only ``num_layers`` (passed to ``ResNet``, default 18)
                is read; any other keys are ignored.
        """
        super().__init__()
        with self.init_scope():
            self.feature_extractor = ResNet(kwargs.pop('num_layers', 18))
            # self.lstm = L.LSTM(None, 1024)
            self.classifier = L.Linear(None, num_classes)
        self.num_chars = num_chars
        # NOTE: this mutates Chainer's *global* config, not a per-instance
        # setting. Within this class it is only referenced by the disabled
        # grayscale conversion in __call__.
        chainer.global_config.user_text_recognition_grayscale_input = False

    @maybe_copy
    def __call__(self, rois):
        """Return per-character class scores for every ROI.

        ``maybe_copy`` presumably moves/copies the input to the model's
        device -- NOTE(review): confirm against train_utils.autocopy.

        Args:
            rois: 5-D array/Variable of shape
                (batch_size, num_bboxes, num_channels, height, width).

        Returns:
            Variable of shape (batch_size, num_bboxes, num_chars, num_classes)
            holding unnormalised classifier scores.
        """
        batch_size, num_bboxes, num_channels, height, width = rois.shape
        # Fold the box axis into the batch axis so the CNN sees 4-D images.
        rois = F.reshape(rois, (-1, num_channels, height, width))

        # Disabled grayscale conversion (kept for reference):
        # if not chainer.config.user_text_recognition_grayscale_input:
        #     # convert data to grayscale
        #     assert rois.shape[1] == 3, "rois are not in RGB, can not convert them to grayscale"
        #     r, g, b = F.separate(rois, axis=1)
        #     grey = 0.299 * r + 0.587 * g + 0.114 * b
        #     rois = F.stack([grey, grey, grey], axis=1)

        h = self.feature_extractor(rois)
        _, num_channels, feature_height, feature_width = h.shape
        # Global average pooling over the whole feature map -> (B * N, C, 1, 1).
        h = F.average_pooling_2d(h, (feature_height, feature_width))
        h = F.reshape(h, (batch_size, num_bboxes, num_channels, -1))

        all_predictions = []
        for box in F.separate(h, axis=1):
            # box: (batch_size, C, 1); L.Linear flattens the trailing axes.
            # Without the LSTM every character slot gets identical scores, so
            # compute them once and repeat them. If the stateful LSTM below is
            # re-enabled, go back to one call per character step:
            # box_predictions = [self.classifier(self.lstm(box)) for _ in range(self.num_chars)]
            char_scores = self.classifier(box)
            box_predictions = [char_scores] * self.num_chars
            # -> (batch_size, num_chars, num_classes)
            all_predictions.append(F.stack(box_predictions, axis=1))
        # Stack boxes on axis 1: (batch_size, num_bboxes, num_chars, num_classes),
        # the layout calc_loss and decode_prediction split on.
        return F.stack(all_predictions, axis=1)

    def calc_loss(self, predictions, labels):
        """Mean softmax cross-entropy over all boxes, characters and samples.

        Args:
            predictions: Output of ``__call__``,
                (batch_size, num_bboxes, num_chars, num_classes).
            labels: Class indices whose axis 1 indexes boxes and axis 2
                indexes characters -- presumably an integer array of shape
                (batch_size, num_bboxes, num_chars); TODO confirm with caller.

        Returns:
            Scalar Variable with the mean per-character loss.
        """
        recognition_losses = []
        assert predictions.shape[1] == labels.shape[1], "Number of boxes is not equal in predictions and labels"
        for box, box_labels in zip(F.separate(predictions, axis=1), F.separate(labels, axis=1)):
            assert box.shape[1] == box_labels.shape[1], "Number of predicted chars is not equal to number of chars in label"
            # reduce="no" keeps one loss per sample, shape (batch_size,).
            box_losses = [
                F.softmax_cross_entropy(char, char_label, reduce="no")
                for char, char_label in zip(F.separate(box, axis=1), F.separate(box_labels, axis=1))
            ]
            # (num_chars, batch_size)
            recognition_losses.append(F.stack(box_losses))
        # (num_bboxes, num_chars, batch_size) -> scalar mean
        return F.mean(F.stack(recognition_losses))

    def decode_prediction(self, prediction):
        """Turn class scores into predicted class indices.

        Args:
            prediction: (batch_size, num_bboxes, num_chars, num_classes) scores.

        Returns:
            Integer Variable of shape (batch_size, num_bboxes, num_chars).
        """
        words = []
        for box in F.separate(prediction, axis=1):
            # argmax of raw scores equals argmax of their softmax (softmax is
            # monotonic), so the normalisation step is skipped.
            word = [F.argmax(character, axis=1) for character in F.separate(box, axis=1)]
            words.append(F.stack(word, axis=1))
        return F.stack(words, axis=1)

    @maybe_copy
    def predict(self, images, return_visual_backprop=False):
        """Run inference and return decoded class indices.

        Args:
            images: A batched ROI array, or a list/tuple of per-sample arrays
                that are converted with ``xp.asarray`` and stacked along a new
                leading axis.
            return_visual_backprop: Accepted for interface compatibility;
                currently unused.

        Returns:
            Integer Variable of shape (batch_size, num_bboxes, num_chars).
        """
        if isinstance(images, (list, tuple)):
            images = self.xp.stack([self.xp.asarray(image) for image in images], axis=0)

        # Inference only: switch off train-mode behaviour and skip building
        # a backward graph that would never be used.
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            text_recognition_result = self(images)
            return self.decode_prediction(text_recognition_result)