import argparse
import io

import mmcv
import onnx
import torch
from mmcv.runner import load_checkpoint
from onnx import optimizer
from torch.onnx import OperatorExportTypes

from mmdet.models import build_detector
from mmdet.ops import RoIAlign, RoIPool


def export_onnx_model(model, inputs, passes):
    """Trace and export a model to onnx format. Modified from
    https://github.com/facebookresearch/detectron2/

    Args:
        model (nn.Module):
        inputs (tuple[args]): the model will be called by `model(*inputs)`
        passes (None or list[str]): the optimization passed for ONNX model

    Returns:
        an onnx model
    """
    assert isinstance(model, torch.nn.Module)

    # make sure all modules are in eval mode, onnx may change the training
    # state of the module if the states are not consistent
    def _check_eval(module):
        assert not module.training

    model.apply(_check_eval)

    # Export the model to ONNX
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

    # Apply ONNX's Optimization
    if passes is not None:
        all_passes = optimizer.get_available_passes()
        assert all(p in all_passes for p in passes), \
            f'Only {all_passes} are supported'
    onnx_model = optimizer.optimize(onnx_model, passes)
    return onnx_model


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet pytorch model conversion to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--out', type=str, required=True, help='output ONNX filename')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1280, 800],
        help='input image size')
    parser.add_argument(
        '--passes', type=str, nargs='+', help='ONNX optimization passes')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if not args.out.endswith('.onnx'):
        raise ValueError('The output file must be a onnx file.')

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load checkpoint
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    # Only support CPU mode for now
    model.cpu().eval()
    # Customized ops are not supported, use torchvision ops instead.
    for m in model.modules():
        if isinstance(m, (RoIPool, RoIAlign)):
            # set use_torchvision on-the-fly
            m.use_torchvision = True

    # TODO: a better way to override forward function
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'ONNX conversion is currently not currently supported with '
            f'{model.__class__.__name__}')

    input_data = torch.empty((1, *input_shape),
                             dtype=next(model.parameters()).dtype,
                             device=next(model.parameters()).device)

    onnx_model = export_onnx_model(model, (input_data, ), args.passes)
    # Print a human readable representation of the graph
    onnx.helper.printable_graph(onnx_model.graph)
    print(f'saving model in {args.out}')
    onnx.save(onnx_model, args.out)


if __name__ == '__main__':
    main()