# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
#
# Based on:
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
bn_helper maintains "faithful" BN stats (as opposed to "running" average
stats) during training. It computes the true mean/std on a
sufficiently large training batch, which is then used for test/val.
"Faithful" BN stats are more reliable than "running" stats when we monitor
the val curves during training, but they often do not improve final results.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging
import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, core
from core.config import config as cfg

from models import model_builder_video
from utils.timer import Timer
import utils.misc as misc


logger = logging.getLogger(__name__)


class BatchNormHelper(object):
    """Compute BN statistics from training batches and feed them to the GPUs.

    An auxiliary forward-only copy of the model is run for
    cfg.TRAIN.ITER_COMPUTE_PRECISE_BN iterations. The per-batch mean and
    inverse std blobs ("<layer>_bn_sm" / "<layer>_bn_siv") are averaged
    across iterations and GPUs, and the result is written into the
    "<layer>_bn_rm" / "<layer>_bn_riv" blobs.
    """

    # Suffix of the BN parameter blob used to discover BN layers.
    _BN_SCALE_SUFFIX = '_bn_s'

    def __init__(self):
        self._model = None
        self._bn_layers = None

        self._meanX_dict = {}  # Replace "rm".
        self._meanX2_dict = {}

        self._var_dict = {}  # Replace "riv".

        self._last_update_iter = -1  # Log the last update iter.

    @staticmethod
    def _iter_label(curr_iter):
        """Return the iteration number to show in log messages.

        Logs report `curr_iter + 1`. `curr_iter` may be None (the default of
        compute_and_update_bn_stats), in which case adding 1 would raise a
        TypeError, so a placeholder is returned instead.
        """
        if curr_iter is None:
            return 'unknown'
        return str(curr_iter + 1)

    def create_bn_aux_model(self, node_id):
        """
        Build the auxiliary net and prepare the statistics buffers.

        bn_aux_model:
        1. It is like "train", as it uses training data.
        2. It is like "train", as only the "train" mode of bn returns sm/siv
        (sm/siv: the mean and inverse *std* of the current batch).
        3. It is like "val/test", as it does not backprop and does not update.
        4. Note: "rm/riv" is fully irrelevant in bn_aux_model.

        Args:
            node_id: forwarded unchanged to ModelBuilder.build_model.
        """
        self._model = model_builder_video.ModelBuilder(
            name='{}_bn_aux'.format(cfg.MODEL.MODEL_NAME),
            train=True,
            use_cudnn=True,
            cudnn_exhaustive_search=True,
            ws_nbytes_limit=(cfg.CUDNN_WORKSPACE_LIMIT * 1024 * 1024),
            split=cfg.TRAIN.DATA_TYPE,
            use_mem_cache=False,  # We don't cache here.
            force_fw_only=True,
        )
        self._model.build_model(node_id=node_id)

        workspace.CreateNet(self._model.net)
        # self._model.start_data_loader()

        misc.save_net_proto(self._model.net)

        self._find_bn_layers()
        self._clean_and_reset_buffer()

    def compute_and_update_bn_stats(self, curr_iter=None):
        """
        We update BN before: (i) testing and (ii) checkpointing.
        They may have different periods.
        To ensure test results are reproducible (not changed by new BN stats),
        We only compute new stats if curr_iter changes.

        Args:
            curr_iter: current training iteration. When it is None, or
                differs from the iteration of the last computation, the
                stats are recomputed; otherwise the cached stats are only
                copied to the GPUs again.
        """
        iter_label = self._iter_label(curr_iter)

        if curr_iter is None or curr_iter != self._last_update_iter:
            logger.info('Computing and updating BN stats at iter: {}'.format(
                iter_label))

            self._last_update_iter = curr_iter
            self._clean_and_reset_buffer()

            num_iters = cfg.TRAIN.ITER_COMPUTE_PRECISE_BN
            timer = Timer()
            for i in range(num_iters):
                timer.tic()
                workspace.RunNet(self._model.net.Proto().name)
                self._collect_bn_stats()
                timer.toc()

                if (i + 1) % cfg.LOG_PERIOD == 0:
                    logger.info('Computing BN [{}/{}]: {:.3}s'.format(
                        i + 1, num_iters, timer.diff))

            self._finalize_bn_stats()
            self._update_bn_stats_gpu()
        else:
            logger.info('BN of iter {} computed. Update to GPU only.'.format(
                iter_label))
            self._update_bn_stats_gpu()

    def _find_bn_layers(self):
        """Record the unscoped name of every BN layer of the aux model.

        A BN layer is identified by a parameter blob whose unscoped name ends
        with "_bn_s"; the layer name is that blob name without the suffix.
        Each layer is listed once, in order of first appearance.
        """
        suffix = self._BN_SCALE_SUFFIX
        self._bn_layers = []
        for blob in self._model.params:
            blob_name = misc.unscope_name(str(blob))
            if not blob_name.endswith(suffix):
                continue
            bn_layer = blob_name[:-len(suffix)]
            if bn_layer not in self._bn_layers:
                self._bn_layers.append(bn_layer)

    def _clean_and_reset_buffer(self):
        """Reset the accumulators of mean(X) and mean(X**2) to zero."""
        self._meanX_dict = {bn_layer: 0 for bn_layer in self._bn_layers}
        self._meanX2_dict = {bn_layer: 0 for bn_layer in self._bn_layers}

    def _collect_bn_stats(self):
        """
        Accumulate the stats of the most recent forward pass, for all GPUs.

        # let
        x = workspace.FetchBlob(layername)
        x = x.transpose((1, 0, 2, 3))
        x = x.reshape((x.shape[0], -1))
        # then:
        # sm == np.mean(x, axis=1)
        # siv == 1. / np.sqrt(np.var(x, axis=1) + cfg.MODEL.BN_EPSILON)

        We maintain meanX and meanX2 (X2 = X**2) which are additive.
        """
        bn_eps = cfg.MODEL.BN_EPSILON
        num_gpus = cfg.NUM_GPUS
        root_gpu_id = cfg.ROOT_GPU_ID
        for gpu_id in range(root_gpu_id, root_gpu_id + num_gpus):
            for bn_layer in self._bn_layers:
                layername = 'gpu_{}/'.format(gpu_id) + bn_layer

                single_batch_meanX = workspace.FetchBlob(
                    layername + '_bn_sm')
                single_batch_inv_std = workspace.FetchBlob(
                    layername + '_bn_siv')

                # Invert siv = 1 / sqrt(var + eps) to recover the variance.
                single_batch_var = (1. / single_batch_inv_std) ** 2 - bn_eps
                # var = mean(x ** 2) - mean(x) ** 2
                # np.mean(x ** 2, axis=1) - np.mean(x, axis=1) ** 2
                single_batch_meanX2 = \
                    single_batch_var + single_batch_meanX ** 2

                self._meanX_dict[bn_layer] += single_batch_meanX
                self._meanX2_dict[bn_layer] += single_batch_meanX2

    def _finalize_bn_stats(self):
        """Update the CPU cache.

        Turns the accumulated sums into averages over all iterations and
        GPUs, then derives the variance of each layer from them.

        Raises:
            AssertionError: if any variance entry of a layer is not
                strictly positive.
        """

        normalize = cfg.TRAIN.ITER_COMPUTE_PRECISE_BN * cfg.NUM_GPUS

        self._var_dict = {}

        for bn_layer in self._bn_layers:
            self._meanX_dict[bn_layer] /= normalize
            self._meanX2_dict[bn_layer] /= normalize

            # var = mean(x ** 2) - mean(x) ** 2
            var = self._meanX2_dict[bn_layer] - self._meanX_dict[bn_layer] ** 2

            assert (var > 0.).all(), "layer: {} var <= 0".format(bn_layer)
            self._var_dict[bn_layer] = var

    def _update_bn_stats_gpu(self):
        """
        Copy to GPU.
        Note: the actual blobs used at test time are "rm" and "riv"
        """

        num_gpus = cfg.NUM_GPUS
        root_gpu_id = cfg.ROOT_GPU_ID
        for gpu_id in range(root_gpu_id, root_gpu_id + num_gpus):
            scope = 'gpu_{}/'.format(gpu_id)
            with core.DeviceScope(
                    core.DeviceOption(caffe2_pb2.CUDA, gpu_id)):
                for bn_layer in self._bn_layers:
                    workspace.FeedBlob(
                        scope + bn_layer + '_bn_rm',
                        np.array(self._meanX_dict[bn_layer], dtype=np.float32),
                    )
                    # Note: riv is actually running var (not running inv
                    # var)!!!!
                    workspace.FeedBlob(
                        scope + bn_layer + '_bn_riv',
                        np.array(self._var_dict[bn_layer], dtype=np.float32),
                    )