# Copyright 2015 Internap.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import re
from functools import wraps

from flask import request
from netaddr import IPNetwork, AddrFormatError, IPAddress

from netman.api.api_utils import BadRequest, MultiContext
from netman.core.objects.access_groups import IN, OUT
from netman.core.objects.exceptions import UnknownResource, BadVlanNumber, \
    BadVlanName, BadBondNumber, BadBondLinkSpeed, MalformedSwitchSessionRequest, \
    BadVrrpGroupNumber
from netman.core.objects.unicast_rpf_modes import STRICT


def resource(*validators):

    def resource_decorator(fn):
        @wraps(fn)
        def wrapper(self, **kwargs):
            with MultiContext(self, kwargs, *validators) as ctxs:
                return fn(self, *ctxs, **kwargs)

        return wrapper

    return resource_decorator


def content(validator_fn):

    def content_decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            kwargs.update(validator_fn(request.data))
            return fn(*args, **kwargs)

        return wrapper

    return content_decorator


class Vlan:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.vlan = None

    def process(self, parameters):
        self.vlan = is_vlan_number(parameters.pop('vlan_number'))['vlan_number']

    def __enter__(self):
        return self.vlan

    def __exit__(self, *_):
        pass


class Bond:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.bond = None

    def process(self, parameters):
        self.bond = is_bond_number(parameters.pop('bond_number'))['bond_number']

    def __enter__(self):
        return self.bond

    def __exit__(self, *_):
        pass


class IPNetworkResource:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.ip_network = None

    def process(self, parameters):
        try:
            self.ip_network = is_ip_network(parameters.pop('ip_network'))['validated_ip_network']
        except BadRequest:
            raise BadRequest('Malformed IP, should be : x.x.x.x/xx')

    def __enter__(self):
        return self.ip_network

    def __exit__(self, *_):
        pass


class Switch:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.is_session = False
        self.switch = None

    def process(self, parameters):
        hostname = parameters.pop('hostname')
        try:
            self.switch = self.switch_api.resolve_session(hostname)
            self.is_session = True
        except UnknownResource:
            self.switch = self.switch_api.resolve_switch(hostname)

    def __enter__(self):
        if not self.is_session:
            self.switch.connect()
        return self.switch

    def __exit__(self, *_):
        if not self.is_session:
            self.switch.disconnect()


class Session:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.session = None

    def process(self, parameters):
        self.session = parameters.pop('session_id')
        self.switch_api.resolve_session(self.session)

    def __enter__(self):
        return self.session

    def __exit__(self, *_):
        pass


class Interface:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.interface = None

    def process(self, parameters):
        self.interface = parameters.pop('interface_id')

    def __enter__(self):
        return self.interface

    def __exit__(self, *_):
        pass


class Resource:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.resource = None

    def process(self, parameters):
        self.resource = parameters.pop('resource')

    def __enter__(self):
        return self.resource

    def __exit__(self, *_):
        pass


class Direction:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.direction = None

    def process(self, parameters):
        direction = parameters.pop('direction')
        if direction.lower() == 'in':
            self.direction = IN
        elif direction.lower() == 'out':
            self.direction = OUT
        else:
            raise UnknownResource("Unknown direction : {}".format(direction))

    def __enter__(self):
        return self.direction

    def __exit__(self, *_):
        pass


class VrrpGroup:
    def __init__(self, switch_api):
        self.switch_api = switch_api
        self.vrrp_group_id = None

    def process(self, parameters):
        try:
            self.vrrp_group_id = int(parameters.pop('vrrp_group_id'))
            if not 1 <= self.vrrp_group_id <= 255:
                raise BadVrrpGroupNumber()
        except (ValueError, KeyError):
            raise BadVrrpGroupNumber()

    def __enter__(self):
        return self.vrrp_group_id

    def __exit__(self, *_):
        pass


def is_session(data, **_):
    try:
        json_data = json.loads(data)
    except ValueError:
        raise BadRequest("Malformed content, should be a JSON object")

    if "hostname" not in json_data:
        raise MalformedSwitchSessionRequest()
    return {
        'hostname': json_data["hostname"]
    }


def is_vlan(data, **_):
    try:
        json_data = json.loads(data)
    except ValueError:
        raise BadRequest("Malformed content, should be a JSON object")

    if "number" not in json_data:
        raise BadVlanNumber()

    name = json_data["name"] if "name" in json_data and len(json_data["name"]) > 0 else None
    if name and " " in name:
        raise BadVlanName()

    return {
        'number': is_vlan_number(json_data["number"])['vlan_number'],
        'name': name
    }


def is_vlan_number(vlan_number, **_):
    try:
        vlan_int = int(vlan_number)
    except ValueError:
        logging.getLogger("netman.api").info("Rejected vlan content : {}".format(repr(vlan_number)))
        raise BadVlanNumber()

    if not 1 <= vlan_int <= 4094:
        logging.getLogger("netman.api").info("Rejected vlan number : {}".format(vlan_number))
        raise BadVlanNumber()

    return {'vlan_number': vlan_int}


def is_ip_network(data, **_):
    try:
        try:
            json_addr = json.loads(data)
            ip = IPNetwork("{}/{}".format(json_addr["address"], json_addr["mask"]))
        except ValueError:
            ip = IPNetwork(data)
    except (KeyError, AddrFormatError):
        raise BadRequest('Malformed content, should be : x.x.x.x/xx or {"address": "x.x.x.x", "mask": "xx"}')

    return {'validated_ip_network': ip}


def is_vrrp_group(data, **_):
    try:
        data = json.loads(data)
    except ValueError:
        raise BadRequest("Malformed content, should be a JSON object")

    if data.get('id') is None:
        raise BadRequest("VRRP group id is mandatory")

    return dict(
        group_id=data.pop('id'),
        ips=[validate_ip_address(i) for i in data.pop('ips', [])],
        **data
    )


def is_int(number, **_):
    try:
        value = int(number)
    except ValueError:
        raise BadRequest('Expected integer content, got "{}"'.format(number))

    return {'value': value}


def is_boolean(option, **_):
    option = option.lower()
    if option not in ['true', 'false']:
        raise BadRequest('Unreadable content "{}". Should be either "true" or "false"'.format(option))

    return {'state': option == 'true'}


def is_access_group_name(data, **_):
    if data == "" or " " in data:
        raise BadRequest('Malformed access group name')

    return {'access_group_name': data}


def is_vrf_name(data, **_):
    if data == "" or " " in data:
        raise BadRequest('Malformed VRF name')

    return {'vrf_name': data}


def is_bond_number(bond_number, **_):
    try:
        bond_number_int = int(bond_number)
    except ValueError:
        logging.getLogger("netman.api").info("Rejected number content : {}".format(repr(bond_number)))
        raise BadBondNumber()

    return {'bond_number': bond_number_int}


def is_bond(data, **_):
    try:
        json_data = json.loads(data)
    except ValueError:
        raise BadRequest("Malformed content, should be a JSON object")

    if "number" not in json_data:
        raise BadBondNumber()

    return {
        'bond_number': is_bond_number(json_data["number"])['bond_number'],
    }


def is_unincast_rpf_mode(data, **_):
    if data not in [STRICT]:
        raise BadRequest('Invalid unicast rpf mode')

    return {'mode': data}


def is_bond_link_speed(data, **_):
    if re.match(r'^\d+[mg]$', data):
        return {'bond_link_speed': data}

    raise BadBondLinkSpeed()


def is_description(description, **_):
    return {'description': description}


def is_dict_with(**fields):
    def m(data, **_):
        try:
            result = json.loads(data)
        except ValueError:
            raise BadRequest("Malformed JSON request")
        for field, validator in fields.iteritems():
            validator(result, field)
        for field, validator in result.iteritems():
            if field not in fields:
                raise BadRequest("Unknown key: {}".format(field))
        return result

    return m


def validate_ip_address(data):
    try:
        return IPAddress(data)
    except:
        raise BadRequest("Incorrect IP Address: \"{}\", should be x.x.x.x".format(data))


def optional(sub_validator):
    def m(params, key):
        if key in params:
            sub_validator(params, key)
    return m


def is_type(obj_type):
    def m(params, key):
        if not isinstance(params[key], obj_type):
            raise BadRequest('Expected "{}" type for key {}, got "{}"'.format(obj_type.__name__, key, type(params[key]).__name__))
    return m