# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import unittest
from lightgbm import LGBMClassifier, LGBMRegressor
from skl2onnx import update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes,  # noqa
    calculate_linear_regressor_output_shapes,
)
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
    convert_lightgbm  # noqa
)

# Import the dump helpers from ``test_utils``. If the module is not already
# importable, add the sibling ``../tests`` directory (relative to this file)
# to ``sys.path`` and retry.
try:
    from test_utils import dump_single_regression
except ImportError:
    import os
    import sys
    sys.path.append(
        os.path.join(
            os.path.dirname(__file__), "..", "tests"))
    from test_utils import dump_single_regression
# By this point ``test_utils`` has been imported successfully one way or the
# other, so the remaining helpers are imported without a guard.
from test_utils import dump_binary_classification, dump_multiple_classification


class TestLightGbmTreeEnsembleModels(unittest.TestCase):
    """Run LightGBM estimators through the ``test_utils`` dump helpers.

    ``convert_lightgbm`` is registered once for the whole class via
    ``update_registered_converter``; each test then builds a small LightGBM
    model and hands it to one of the dump helpers.
    """

    @classmethod
    def setUpClass(cls):
        """Register ``convert_lightgbm`` for both LightGBM estimator types."""
        # The classifier registration additionally passes an options mapping
        # for 'zipmap' and 'nocl'.
        update_registered_converter(
            LGBMClassifier, 'LightGbmLGBMClassifier',
            calculate_linear_classifier_output_shapes,
            convert_lightgbm, options={
                'zipmap': [True, False], 'nocl': [True, False]})

        update_registered_converter(
            LGBMRegressor, 'LgbmRegressor',
            calculate_linear_regressor_output_shapes,
            convert_lightgbm)

    def test_lightgbm_classifier(self):
        """Run the binary and multiple classification dump helpers.

        The same 3-estimator classifier instance is passed to both helpers.
        """
        # Passed as a string expression; presumably evaluated inside
        # test_utils to decide whether a failure is tolerated on old onnx
        # versions -- TODO confirm against test_utils.
        old_onnx = ("StrictVersion(onnx.__version__) < "
                    "StrictVersion('1.3.0')")
        model = LGBMClassifier(n_estimators=3, min_child_samples=1)
        dump_binary_classification(model, allow_failure=old_onnx)
        dump_multiple_classification(model, allow_failure=old_onnx)

    def test_lightgbm_regressor(self):
        """Run the single regression dump helper on a 3-estimator model."""
        model = LGBMRegressor(n_estimators=3, min_child_samples=1)
        dump_single_regression(model)

    def test_lightgbm_regressor1(self):
        """Run the single regression dump helper on a 1-estimator model."""
        model = LGBMRegressor(n_estimators=1, min_child_samples=1)
        # The suffix is forwarded to the helper; presumably it keeps this
        # dump distinct from the other regressor tests -- TODO confirm.
        dump_single_regression(model, suffix="1")

    def test_lightgbm_regressor2(self):
        """Run the single regression dump helper on 2 estimators, depth 1."""
        model = LGBMRegressor(n_estimators=2, max_depth=1, min_child_samples=1)
        dump_single_regression(model, suffix="2")


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()