# ******************************************************************************
# Copyright 2018-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import onnx
from typing import List

from google.protobuf.message import DecodeError

from ngraph.impl import Function
from ngraph.impl import onnx_import
from ngraph.exceptions import UserInputError


def import_onnx_model(onnx_protobuf):  # type: (onnx.ModelProto) -> List[Function]
    """
    Convert an in-memory ONNX model into a list of ngraph Functions.

    The model is serialized back to its Protocol Buffers byte form and handed
    to the native ``onnx_import.import_onnx_model`` binding.

    :param onnx_protobuf: ONNX Protocol Buffers model (onnx_pb2.ModelProto object)
    :return: list of ngraph Functions produced by the native importer.
    :raises UserInputError: if ``onnx_protobuf`` is not an ``onnx.ModelProto``.
    """
    if isinstance(onnx_protobuf, onnx.ModelProto):
        # The native binding consumes serialized bytes, not the Python message.
        serialized_model = onnx_protobuf.SerializeToString()
        return onnx_import.import_onnx_model(serialized_model)

    raise UserInputError('Input does not seem to be a properly formatted ONNX model.')


def import_onnx_file(filename):  # type: (str) -> List[Function]
    """
    Import ONNX model from a Protocol Buffers file and convert to ngraph functions.

    :param filename: path to an ONNX file
    :return: List of imported ngraph Functions (see docs for import_onnx_model).
    :raises UserInputError: if the file contents cannot be decoded as an ONNX
        Protocol Buffers model. The underlying ``DecodeError`` is chained as
        the cause. Other errors raised by ``onnx.load`` (e.g. a missing file)
        propagate unchanged.
    """
    try:
        onnx_protobuf = onnx.load(filename)
    except DecodeError as err:
        # Chain the decode error so the traceback shows the real cause instead
        # of "another exception occurred during handling".
        raise UserInputError(
            "The provided file doesn't contain a properly formatted ONNX model."
        ) from err

    # The native binding consumes serialized bytes, not the Python message.
    return onnx_import.import_onnx_model(onnx_protobuf.SerializeToString())