from chainer import Chain
from chainer import cuda

import chainer.functions as F
import chainer.links as L
from functions.disable_shearing import disable_shearing
from functions.disable_translation import disable_translation
from functions.rotation_droput import rotation_dropout
from insights.visual_backprop import VisualBackprop


class ResnetBlock(Chain):

    def __init__(self, num_filter, filter_increase=False, use_dropout=False, dropout_ratio=0.5):
        super().__init__()
        with self.init_scope():
            self.conv0 = L.Convolution2D(None, num_filter, 3, pad=1)
            self.bn0 = L.BatchNormalization(num_filter)
            self.conv1 = L.Convolution2D(num_filter, num_filter, 3, pad=1)
            self.bn1 = L.BatchNormalization(num_filter)
            self.use_dropout = use_dropout
            self.dropout_ratio = dropout_ratio

            if filter_increase:
                self.conv2 = L.Convolution2D(None, num_filter, 1)
                self.bn2 = L.BatchNormalization(num_filter)

        self.filter_increase = filter_increase
        self._train = True

    @property
    def train(self):
        return self._train

    @train.setter
    def train(self, value):
        self._train = value

    def __call__(self, x):
        h = self.bn0(self.conv0(x))
        if self.use_dropout:
            h = F.dropout(h, ratio=self.dropout_ratio)
        h = F.relu(h)
        h = self.bn1(self.conv1(h))
        if self.use_dropout:
            h = F.dropout(h, ratio=self.dropout_ratio)

        if self.filter_increase:
            h_pre = self.bn2(self.conv2(x))
            h = h + h_pre

        h = F.relu(h)
        return h
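
# Usage sketch for ResnetBlock: the 3x3 convolutions keep the spatial resolution (pad=1),
# and a shortcut from the block input is only added when filter_increase=True, in which
# case a 1x1 convolution projects the input to the new channel count first, e.g.:
#   block = ResnetBlock(64, filter_increase=True)
#   y = block(x)  # x: (N, C_in, H, W) -> y: (N, 64, H, W)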


class FSNSMultipleSTNLocalizationNet(Chain):

    def __init__(self, dropout_factor, num_timesteps, zoom=0.9):
        super(FSNSMultipleSTNLocalizationNet, self).__init__()
        with self.init_scope():
            self.conv0 = L.Convolution2D(None, 32, 3, pad=1)
            self.bn0 = L.BatchNormalization(32)
            self.rs1 = ResnetBlock(32)
            self.rs2 = ResnetBlock(48, filter_increase=True)
            self.rs3 = ResnetBlock(48)
            self.lstm = L.LSTM(None, 256)
            self.translation_transform = L.Linear(256, 6)
            self.rotation_transform = L.Linear(256, 6)
            self.transform_2 = L.LSTM(256, 6)

        self.dropout_factor = dropout_factor
        self._train = True
        self.num_timesteps = num_timesteps

        for transform in [self.translation_transform, self.rotation_transform]:
            transform_bias = transform.b.data
            transform_bias[[0, 4]] = zoom
            transform_bias[[2, 5]] = 0
            transform.W.data[...] = 0

        # self.transform_2.upward.b.data[...] = 0
        # self.transform_2.upward.W.data[...] = 0
        # self.transform_2.lateral.W.data[...] = 0

        # self.transform.W.data[...] = 0

        self.visual_backprop = VisualBackprop()
        self.vis_anchor = None

    @property
    def train(self):
        return self._train

    @train.setter
    def train(self, value):
        self._train = value
        self.rs1.train = value
        self.rs2.train = value
        self.rs3.train = value

    def __call__(self, images):
        self.lstm.reset_state()
        self.transform_2.reset_state()

        h = self.bn0(self.conv0(images))
        h = F.average_pooling_2d(F.relu(h), 2, stride=2)

        h = self.rs1(h)
        h = F.max_pooling_2d(h, 2, stride=2)

        h = self.rs2(h)
        h = F.max_pooling_2d(h, 2, stride=2)

        h = self.rs3(h)
        self.vis_anchor = h
        h = F.average_pooling_2d(h, 5, stride=2)

        localizations = []

        with cuda.get_device_from_array(h.data):
            # third row [0, 0, 1] that lifts the 2x3 affine matrices to 3x3 homogeneous matrices
            homogeneous_addon = self.xp.zeros((len(h), 1, 3), dtype=h.data.dtype)
            homogeneous_addon[:, 0, 2] = 1

        for _ in range(self.num_timesteps):
            lstm_prediction = F.relu(self.lstm(h))
            translation_transform = F.reshape(self.translation_transform(lstm_prediction), (-1, 2, 3))
            translation_transform = disable_shearing(translation_transform)
            translation_transform = F.concat((translation_transform, homogeneous_addon), axis=1)

            rotation_transform = F.reshape(self.rotation_transform(lstm_prediction), (-1, 2, 3))
            rotation_transform = disable_translation(rotation_transform)
            rotation_transform = F.concat((rotation_transform, homogeneous_addon), axis=1)

            # first rotate, then translate
            transform = F.batch_matmul(rotation_transform, translation_transform)
            # homogeneous_multiplier = F.get_item(transform, (..., 2, 2))
            #
            # bring matrices from homogeneous coordinates back to normal (2x3) coordinates
            transform = transform[:, :2, :]
            # transform = transform / homogeneous_multiplier
            localizations.append(rotation_dropout(transform, ratio=self.dropout_factor))

        return F.concat(localizations, axis=0)
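
# FSNSMultipleSTNLocalizationNet predicts a translation part (disable_shearing) and a
# rotation part (disable_translation) per timestep, lifts both to homogeneous 3x3 form via
# homogeneous_addon, multiplies them and keeps only the upper 2x3 rows of the product.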


class FSNSSingleSTNLocalizationNet(Chain):

    def __init__(self, dropout_ratio, num_timesteps, zoom=0.9, use_dropout=False):
        super(FSNSSingleSTNLocalizationNet, self).__init__()
        with self.init_scope():
            self.conv0 = L.Convolution2D(None, 32, 3, pad=1)
            self.bn0 = L.BatchNormalization(32)
            self.rs1 = ResnetBlock(32, use_dropout=use_dropout, dropout_ratio=dropout_ratio)
            self.rs2 = ResnetBlock(48, filter_increase=True, use_dropout=use_dropout, dropout_ratio=dropout_ratio)
            self.rs3 = ResnetBlock(48)
            # self.rs4 = ResnetBlock(16, filter_increase=True)
            self.lstm = L.LSTM(None, 256)
            self.transform_2 = L.LSTM(256, 6)

        self.dropout_ratio = dropout_ratio
        self.use_dropout = use_dropout
        self._train = True
        self.num_timesteps = num_timesteps

        # initialize transform
        # self.transform_2.W.data[...] = 0
        #
        # transform_bias = self.transform_2.b.data
        # transform_bias[[0, 4]] = zoom
        # transform_bias[[2, 5]] = 0

        self.visual_backprop = VisualBackprop()
        self.vis_anchor = None

        self.width_encoding = None
        self.height_encoding = None

    def __call__(self, images):
        self.lstm.reset_state()
        self.transform_2.reset_state()

        h = self.bn0(self.conv0(images))
        h = F.average_pooling_2d(F.relu(h), 2, stride=2)

        h = self.rs1(h)
        h = F.max_pooling_2d(h, 2, stride=2)

        h = self.rs2(h)
        h = F.max_pooling_2d(h, 2, stride=2)

        h = self.rs3(h)
        # h = self.rs4(h)
        self.vis_anchor = h
        h = F.average_pooling_2d(h, 5, stride=2)

        localizations = []

        with cuda.get_device_from_array(h.data):
            # lstm_prediction = chainer.Variable(self.xp.zeros((len(images), self.lstm.state_size), dtype=h.dtype))

            for _ in range(self.num_timesteps):
                # in_feature = self.attend(h, lstm_prediction)
                in_feature = h
                lstm_prediction = F.relu(self.lstm(in_feature))
                transformed = self.transform_2(lstm_prediction)
                transformed = F.reshape(transformed, (-1, 2, 3))
                localizations.append(rotation_dropout(transformed, ratio=self.dropout_ratio))

        return F.concat(localizations, axis=0)
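
# Both localization nets return the predicted 2x3 affine matrices of all timesteps
# concatenated along the batch axis, i.e. a Variable of shape
# (num_timesteps * batch_size, 2, 3), which is what F.spatial_transformer_grid expects
# in the recognition nets below.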


class FSNSRecognitionNet(Chain):

    def __init__(self, target_shape, num_labels, num_timesteps, uses_original_data=False, dropout_ratio=0.5, use_dropout=False):
        super(FSNSRecognitionNet, self).__init__()
        with self.init_scope():
            self.conv0 = L.Convolution2D(None, 32, 3, pad=1, stride=2)
            self.bn0 = L.BatchNormalization(32)
            self.conv1 = L.Convolution2D(32, 32, 3, pad=1)
            self.bn1 = L.BatchNormalization(32)
            self.rs1 = ResnetBlock(32, use_dropout=use_dropout, dropout_ratio=dropout_ratio)
            self.rs2 = ResnetBlock(64, filter_increase=True, use_dropout=use_dropout, dropout_ratio=dropout_ratio)
            self.rs3 = ResnetBlock(128, filter_increase=True, use_dropout=use_dropout, dropout_ratio=dropout_ratio)
            self.fc1 = L.Linear(None, 256)
            self.lstm = L.LSTM(None, 256)
            self.classifier = L.Linear(None, 134)

        self._train = True
        self.target_shape = target_shape
        self.num_labels = num_labels
        self.num_timesteps = num_timesteps
        self.uses_original_data = uses_original_data
        self.vis_anchor = None
        self.use_dropout = use_dropout
        self.dropout_ratio = dropout_ratio

    def __call__(self, images, localizations):
        points = F.spatial_transformer_grid(localizations, self.target_shape)
        rois = F.spatial_transformer_sampler(images, points)

        h = F.relu(self.bn0(self.conv0(rois)))
        if self.use_dropout:
            h = F.dropout(h, ratio=self.dropout_ratio)
        h = F.relu(self.bn1(self.conv1(h)))
        if self.use_dropout:
            h = F.dropout(h, ratio=self.dropout_ratio)

        h = self.rs1(h)
        h = self.rs2(h)
        h = F.max_pooling_2d(h, 2, stride=2)
        h = self.rs3(h)
        self.vis_anchor = h

        h = F.average_pooling_2d(h, 5, stride=1)

        if self.uses_original_data:
            # merge data of all 4 individual images in channel dimension
            batch_size, num_channels, height, width = h.shape
            h = F.reshape(h, (batch_size // 4, 4 * num_channels, height, width))

        h = F.relu(self.fc1(h))

        # for each timestep of the localization net do the 'classification'
        h = F.reshape(h, (self.num_timesteps, -1, self.fc1.out_size))
        overall_predictions = []
        for timestep in F.separate(h, axis=0):
            # run 2 * num_labels + 1 LSTM steps, as required by the CTC loss (labels plus possible blanks)
            lstm_predictions = []
            self.lstm.reset_state()
            for _ in range(self.num_labels * 2 + 1):
                lstm_prediction = self.lstm(timestep)
                classified = self.classifier(lstm_prediction)
                lstm_predictions.append(classified)
            overall_predictions.append(lstm_predictions)

        return overall_predictions, rois, points
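
# FSNSRecognitionNet returns a nested list of predictions: one entry per localization
# timestep, each containing 2 * num_labels + 1 classifier outputs of shape
# (batch_size, 134), intended to be consumed by a CTC loss.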


class FSNSSoftmaxRecognitionNet(Chain):

    def __init__(self, target_shape, num_labels, num_timesteps, uses_original_data=False, dropout_ratio=0.5, use_dropout=False, use_blstm=False):
        super().__init__()
        with self.init_scope():
            self.conv0 = L.Convolution2D(None, 32, 3, pad=1)
            self.bn0 = L.BatchNormalization(32)
            self.rs1 = ResnetBlock(32)
            self.rs2 = ResnetBlock(64, filter_increase=True, use_dropout=use_dropout, dropout_ratio=dropout_ratio)
            self.rs3 = ResnetBlock(128, filter_increase=True, use_dropout=use_dropout, dropout_ratio=dropout_ratio)
            self.fc1 = L.Linear(None, 256)
            self.lstm = L.LSTM(None, 256)
            if use_blstm:
                self.blstm = L.LSTM(None, 256)
            self.classifier = L.Linear(None, 134)

        self._train = True
        self.target_shape = target_shape
        self.num_labels = num_labels
        self.num_timesteps = num_timesteps
        self.uses_original_data = uses_original_data
        self.vis_anchor = None
        self.use_dropout = use_dropout
        self.dropout_ratio = dropout_ratio
        self.use_blstm = use_blstm

    def __call__(self, images, localizations):
        points = F.spatial_transformer_grid(localizations, self.target_shape)
        rois = F.spatial_transformer_sampler(images, points)

        h = self.bn0(self.conv0(rois))
        h = F.average_pooling_2d(F.relu(h), 2, stride=2)

        h = self.rs1(h)
        h = self.rs2(h)
        h = F.max_pooling_2d(h, 2, stride=2)
        h = self.rs3(h)
        self.vis_anchor = h

        h = F.average_pooling_2d(h, 5, stride=1)

        if self.uses_original_data:
            # merge data of all 4 individual images in channel dimension
            batch_size, num_channels, height, width = h.shape
            h = F.reshape(h, (batch_size // 4, 4 * num_channels, height, width))

        h = F.relu(self.fc1(h))

        # for each timestep of the localization net do the 'classification'
        h = F.reshape(h, (self.num_timesteps, -1, self.fc1.out_size))
        overall_predictions = []
        for timestep in F.separate(h, axis=0):
            # predict exactly num_labels characters for this timestep (softmax classification, no CTC blanks)
            lstm_predictions = []
            self.lstm.reset_state()
            if self.use_blstm:
                self.blstm.reset_state()

            for _ in range(self.num_labels):
                lstm_prediction = self.lstm(timestep)
                lstm_predictions.append(lstm_prediction)

            if self.use_blstm:
                blstm_predictions = []
                for lstm_prediction in reversed(lstm_predictions):
                    blstm_prediction = self.blstm(lstm_prediction)
                    blstm_predictions.append(blstm_prediction)

                lstm_predictions = list(reversed(blstm_predictions))

            final_lstm_predictions = []
            for lstm_prediction in lstm_predictions:
                classified = self.classifier(lstm_prediction)
                final_lstm_predictions.append(F.expand_dims(classified, axis=0))

            final_lstm_predictions = F.concat(final_lstm_predictions, axis=0)
            overall_predictions.append(final_lstm_predictions)

        return overall_predictions, rois, points
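
# In contrast to FSNSRecognitionNet, every element of overall_predictions here is a single
# Variable of shape (num_labels, batch_size, 134): one softmax prediction per character
# position instead of a CTC-length sequence.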


class FSNSSoftmaxRecognitionResNet(Chain):

    def __init__(self, target_shape, num_labels, num_timesteps, uses_original_data=False, dropout_ratio=0.5, use_dropout=False, use_blstm=False, use_attention=False):
        super().__init__()
        with self.init_scope():
            self.data_bn = L.BatchNormalization(3)
            self.conv0 = L.Convolution2D(None, 64, 7, stride=2, pad=3, nobias=True)
            self.bn0 = L.BatchNormalization(64)
            self.rs1_1 = ResnetBlock(64)
            self.rs1_2 = ResnetBlock(64)
            self.rs2_1 = ResnetBlock(128, filter_increase=True)
            self.rs2_2 = ResnetBlock(128)
            self.rs3_1 = ResnetBlock(256, filter_increase=True)
            self.rs3_2 = ResnetBlock(256)
            self.rs4_1 = ResnetBlock(512, filter_increase=True)
            self.rs4_2 = ResnetBlock(512)
            self.fc1 = L.Linear(None, 512)
            self.lstm = L.LSTM(None, 512)
            if use_blstm:
                self.blstm = L.LSTM(None, 512)
            if use_attention:
                self.transform_encoded_features = L.Linear(512, 512, nobias=True)
                self.transform_out_lstm_feature = L.Linear(512, 512, nobias=True)
                self.generate_attended_feat = L.Linear(512, 1)
                self.out_lstm = L.LSTM(512, 512)
            self.classifier = L.Linear(None, 134)

        self._train = True
        self.target_shape = target_shape
        self.num_labels = num_labels
        self.num_timesteps = num_timesteps
        self.uses_original_data = uses_original_data
        self.vis_anchor = None
        self.use_dropout = use_dropout
        self.dropout_ratio = dropout_ratio
        self.use_blstm = use_blstm
        self.use_attention = use_attention

    def attend(self, encoded_features):
        self.out_lstm.reset_state()
        transformed_encoded_features = F.concat([F.expand_dims(self.transform_encoded_features(feature), axis=1) for feature in encoded_features], axis=1)
        concat_encoded_features = F.concat([F.expand_dims(e, axis=1) for e in encoded_features], axis=1)

        lstm_output = self.xp.zeros_like(encoded_features[0].data)
        outputs = []
        for _ in range(self.num_labels):
            transformed_lstm_output = self.transform_out_lstm_feature(lstm_output)
            attended_feats = []
            for transformed_encoded_feature in F.separate(transformed_encoded_features, axis=1):
                attended_feat = transformed_encoded_feature + transformed_lstm_output
                attended_feat = F.tanh(attended_feat)
                attended_feats.append(self.generate_attended_feat(attended_feat))

            attended_feats = F.concat(attended_feats, axis=1)
            alphas = F.softmax(attended_feats, axis=1)

            lstm_input_feature = F.batch_matmul(alphas, concat_encoded_features, transa=True)
            lstm_input_feature = F.squeeze(lstm_input_feature, axis=1)
            lstm_output = self.out_lstm(lstm_input_feature)
            outputs.append(lstm_output)
        return outputs
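
    # attend() implements additive (Bahdanau-style) attention over the encoded features
    # f_1..f_T. Per output step t it roughly computes:
    #   e_{t,i} = v^T tanh(W_e f_i + W_h h_{t-1})   (transform_* layers + generate_attended_feat)
    #   alpha_t = softmax(e_t)
    #   c_t     = sum_i alpha_{t,i} f_i             (batch_matmul with transa=True)
    #   h_t     = out_lstm(c_t)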

    def __call__(self, images, localizations):
        points = F.spatial_transformer_grid(localizations, self.target_shape)
        rois = F.spatial_transformer_sampler(images, points)

        connected_rois = self.data_bn(rois)
        h = F.relu(self.bn0(self.conv0(connected_rois)))
        h = F.max_pooling_2d(h, 3, stride=2, pad=1)

        h = self.rs1_1(h)
        h = self.rs1_2(h)

        h = self.rs2_1(h)
        h = self.rs2_2(h)

        h = self.rs3_1(h)
        h = self.rs3_2(h)

        h = self.rs4_1(h)
        h = self.rs4_2(h)

        self.vis_anchor = h

        h = F.average_pooling_2d(h, 7, stride=1)

        if self.uses_original_data:
            # merge data of all 4 individual images in channel dimension
            batch_size, num_channels, height, width = h.shape
            h = F.reshape(h, (batch_size // 4, 4 * num_channels, height, width))

        h = F.relu(self.fc1(h))

        # for each timestep of the localization net do the 'classification'
        h = F.reshape(h, (self.num_timesteps, -1, self.fc1.out_size))
        overall_predictions = []
        for timestep in F.separate(h, axis=0):
            # predict exactly num_labels characters for this timestep (softmax classification, no CTC blanks)
            lstm_predictions = []
            self.lstm.reset_state()
            if self.use_blstm:
                self.blstm.reset_state()

            for _ in range(self.num_labels):
                lstm_prediction = self.lstm(timestep)
                lstm_predictions.append(lstm_prediction)

            if self.use_blstm:
                blstm_predictions = []
                for lstm_prediction in reversed(lstm_predictions):
                    blstm_prediction = self.blstm(lstm_prediction)
                    blstm_predictions.append(blstm_prediction)

                lstm_predictions = list(reversed(blstm_predictions))

            if self.use_attention:
                lstm_predictions = self.attend(lstm_predictions)

            final_lstm_predictions = []
            for lstm_prediction in lstm_predictions:
                classified = self.classifier(lstm_prediction)
                final_lstm_predictions.append(F.expand_dims(classified, axis=0))

            final_lstm_predictions = F.concat(final_lstm_predictions, axis=0)
            overall_predictions.append(final_lstm_predictions)

        return overall_predictions, rois, points


class FSNSNet(Chain):

    def __init__(self, localization_net, recognition_net, uses_original_data=False):
        super(FSNSNet, self).__init__()
        with self.init_scope():
            self.localization_net = localization_net
            self.recognition_net = recognition_net

        self._train = True
        self.uses_original_data = uses_original_data

    @property
    def train(self):
        return self._train

    @train.setter
    def train(self, value):
        self._train = value
        self.localization_net.train = value
        self.recognition_net.train = value

    def __call__(self, images, label=None):
        if self.uses_original_data:
            # treat each of the 4 individual views as its own batch element
            batch_size, num_channels, height, width = images.shape
            images = F.reshape(images, (batch_size, num_channels, height, 4, -1))
            images = F.transpose(images, (0, 3, 1, 2, 4))
            images = F.reshape(images, (batch_size * 4, num_channels, height, width // 4))

        batch_size = images.shape[0]
        h = self.localization_net(images)
        new_batch_size = h.shape[0]
        batch_size_increase_factor = new_batch_size // batch_size
        images = F.concat([images for _ in range(batch_size_increase_factor)], axis=0)

        if label is None:
            return self.recognition_net(images, h)
        return self.recognition_net(images, h, label)
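
# View handling in FSNSNet.__call__, assuming standard FSNS input where the four 150x150
# views are concatenated horizontally into one 150x600 image:
#   (B, 3, 150, 600) -> reshape   -> (B, 3, 150, 4, 150)
#                    -> transpose -> (B, 4, 3, 150, 150)
#                    -> reshape   -> (4B, 3, 150, 150)
# so every view becomes its own batch element before localization and recognition.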


class FSNSResnetReuseNet(Chain):

    def __init__(self, target_shape, num_timesteps, num_labels, dropout_ratio=0.5, uses_original_data=False, use_blstm=False):
        super().__init__()
        with self.init_scope():
            self.data_bn = L.BatchNormalization(3)
            self.conv0 = L.Convolution2D(None, 64, 7, stride=2, pad=3, nobias=True)
            self.bn0 = L.BatchNormalization(64)
            self.rs1_1 = ResnetBlock(64)
            self.rs1_2 = ResnetBlock(64)
            self.rs2_1 = ResnetBlock(128, filter_increase=True)
            self.rs2_2 = ResnetBlock(128)
            self.rs3_1 = ResnetBlock(256, filter_increase=True)
            self.rs3_2 = ResnetBlock(256)
            # self.rs4_1 = ResnetBlock(512, filter_increase=True)
            # self.rs4_2 = ResnetBlock(512)

            # localization part
            self.lstm = L.LSTM(None, 256)
            self.transform_2 = L.LSTM(256, 6)

            # recognition part
            self.fc1 = L.Linear(None, 256)
            self.recognition_lstm = L.LSTM(None, 256)
            if use_blstm:
                self.recognition_blstm = L.LSTM(None, 256)
            self.classifier = L.Linear(None, 134)

        self.uses_original_data = uses_original_data
        self.use_blstm = use_blstm
        self.num_timesteps = num_timesteps
        self.num_labels = num_labels
        self.dropout_ratio = dropout_ratio
        self.target_shape = target_shape

        self.localization_vis_anchor = None
        self.recognition_vis_anchor = None

    def localization_net(self, images):
        self.lstm.reset_state()
        self.transform_2.reset_state()

        images = self.data_bn(images)
        h = F.relu(self.bn0(self.conv0(images)))
        h = F.max_pooling_2d(h, 3, stride=2, pad=1)

        h = self.rs1_1(h)
        h = self.rs1_2(h)

        h = self.rs2_1(h)
        h = self.rs2_2(h)

        h = self.rs3_1(h)
        h = self.rs3_2(h)

        # h = self.rs4_1(h)
        # h = self.rs4_2(h)

        self.localization_vis_anchor = h

        h = F.average_pooling_2d(h, 5, stride=1)

        localizations = []

        with cuda.get_device_from_array(h.data):
            for _ in range(self.num_timesteps):
                in_feature = h
                lstm_prediction = F.relu(self.lstm(in_feature))
                transformed = self.transform_2(lstm_prediction)
                transformed = F.reshape(transformed, (-1, 2, 3))
                localizations.append(rotation_dropout(transformed, ratio=self.dropout_ratio))

        return F.concat(localizations, axis=0)

    def recognition_net(self, images, localizations):
        points = F.spatial_transformer_grid(localizations, self.target_shape)
        rois = F.spatial_transformer_sampler(images, points)

        connected_rois = self.data_bn(rois)
        h = F.relu(self.bn0(self.conv0(connected_rois)))
        h = F.max_pooling_2d(h, 3, stride=2, pad=1)

        h = self.rs1_1(h)
        h = self.rs1_2(h)

        h = self.rs2_1(h)
        h = self.rs2_2(h)

        h = self.rs3_1(h)
        h = self.rs3_2(h)

        # h = self.rs4_1(h)
        # h = self.rs4_2(h)

        self.recognition_vis_anchor = h

        h = F.average_pooling_2d(h, 5, stride=1)

        if self.uses_original_data:
            # merge data of all 4 individual images in channel dimension
            batch_size, num_channels, height, width = h.shape
            h = F.reshape(h, (batch_size // 4, 4 * num_channels, height, width))

        h = F.relu(self.fc1(h))

        # for each timestep of the localization net do the 'classification'
        h = F.reshape(h, (self.num_timesteps, -1, self.fc1.out_size))
        overall_predictions = []
        for timestep in F.separate(h, axis=0):
            # predict exactly num_labels characters for this timestep (softmax classification, no CTC blanks)
            lstm_predictions = []
            self.recognition_lstm.reset_state()
            if self.use_blstm:
                self.recognition_blstm.reset_state()

            for _ in range(self.num_labels):
                lstm_prediction = self.recognition_lstm(timestep)
                lstm_predictions.append(lstm_prediction)

            if self.use_blstm:
                blstm_predictions = []
                for lstm_prediction in reversed(lstm_predictions):
                    blstm_prediction = self.recognition_blstm(lstm_prediction)
                    blstm_predictions.append(blstm_prediction)

                lstm_predictions = list(reversed(blstm_predictions))

            final_lstm_predictions = []
            for lstm_prediction in lstm_predictions:
                classified = self.classifier(lstm_prediction)
                final_lstm_predictions.append(F.expand_dims(classified, axis=0))

            final_lstm_predictions = F.concat(final_lstm_predictions, axis=0)
            overall_predictions.append(final_lstm_predictions)

        return overall_predictions, rois, points
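
    # Note: localization_net and recognition_net above share the same backbone links
    # (data_bn, conv0, bn0 and all ResnetBlocks); only the LSTM heads differ, which is
    # presumably the "reuse" the class name refers to.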

    def __call__(self, images):
        if self.uses_original_data:
            # treat each of the 4 individual views as its own batch element
            batch_size, num_channels, height, width = images.shape
            images = F.reshape(images, (batch_size, num_channels, height, 4, -1))
            images = F.transpose(images, (0, 3, 1, 2, 4))
            images = F.reshape(images, (batch_size * 4, num_channels, height, width // 4))

        batch_size = images.shape[0]
        localization = self.localization_net(images)
        new_batch_size = localization.shape[0]
        batch_size_increase_factor = new_batch_size // batch_size
        images = F.concat([images for _ in range(batch_size_increase_factor)], axis=0)
        return self.recognition_net(images, localization)
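

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the training pipeline). The input shape and
    # all hyper-parameters below are illustrative assumptions only.
    import numpy as np

    localization_net = FSNSSingleSTNLocalizationNet(dropout_ratio=0.5, num_timesteps=3)
    recognition_net = FSNSSoftmaxRecognitionNet(target_shape=(75, 75), num_labels=10, num_timesteps=3)
    net = FSNSNet(localization_net, recognition_net)

    images = np.random.uniform(-1, 1, (2, 3, 150, 150)).astype(np.float32)
    predictions, rois, points = net(images)
    # one entry per localization timestep, each of shape (num_labels, batch_size, 134)
    print(len(predictions), predictions[0].shape)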