# -*- coding: utf-8 -*-

""" OneLogin_Saml2_Utils class

Copyright (c) 2010-2018 OneLogin, Inc.
MIT License

Auxiliary class of OneLogin's Python Toolkit.

"""

import base64
from copy import deepcopy
import calendar
from datetime import datetime
from hashlib import sha1, sha256, sha384, sha512
from isodate import parse_duration as duration_parser
import re
from textwrap import wrap
from functools import wraps
from uuid import uuid4
from xml.dom.minidom import Element

import zlib
import xmlsec

from onelogin.saml2 import compat
from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.errors import OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError
from onelogin.saml2.xml_utils import OneLogin_Saml2_XML


try:
    from urllib.parse import quote_plus  # py3
except ImportError:
    from urllib import quote_plus  # py2


def return_false_on_exception(func):
    """
    Decorator that converts any exception raised by the wrapped function
    into a ``False`` return value.

    A caller can opt out for a single call by passing the keyword argument
    ``raise_exceptions=True``. That keyword is consumed by the wrapper and is
    never forwarded to the wrapped function.

    :param func: The function to wrap
    :type func: callable

    :return: The wrapping function
    :rtype: callable
    """
    @wraps(func)
    def exceptfalse(*args, **kwargs):
        propagate = kwargs.pop('raise_exceptions', False)
        if propagate:
            # The caller explicitly asked to see the original exception.
            return func(*args, **kwargs)
        try:
            return func(*args, **kwargs)
        except Exception:
            return False
    return exceptfalse


class OneLogin_Saml2_Utils(object):
    """

    Auxiliary class that contains several utility methods to parse time,
    urls, add sign, encrypt, decrypt, sign validation, handle xml ...

    """

    # XPath of a signature placed directly on the Response. validate_sign
    # queries this first when the caller supplies no explicit xpath.
    RESPONSE_SIGNATURE_XPATH = '/samlp:Response/ds:Signature'
    # XPath of a signature placed on the Assertion. validate_sign falls back
    # to this when the Response-level query returns nothing.
    ASSERTION_SIGNATURE_XPATH = '/samlp:Response/saml:Assertion/ds:Signature'

    # Format used to emit SAML timestamps and the first one tried when parsing.
    TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
    # Second format tried when parsing: same as above plus a "%f" fraction.
    TIME_FORMAT_2 = "%Y-%m-%dT%H:%M:%S.%fZ"
    # Last-resort pattern used by parse_SAML_to_time. Group 1 captures the
    # date-time up to whole seconds; the optional fraction is matched but
    # discarded by the caller.
    TIME_FORMAT_WITH_FRAGMENT = re.compile(r'^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})(\.\d*)?Z?$')

    @staticmethod
    def escape_url(url, lowercase_urlencoding=False):
        """
        escape the non-safe symbols in url
        The encoding used by ADFS 3.0 is not compatible with
        python's quote_plus (ADFS produces lower case hex numbers and quote_plus produces
        upper case hex numbers)
        :param url: the url to escape
        :type url: str

        :param lowercase_urlencoding: lowercase or no
        :type lowercase_urlencoding: boolean

        :return: the escaped url
        :rtype str
        """
        encoded = quote_plus(url)
        return re.sub(r"%[A-F0-9]{2}", lambda m: m.group(0).lower(), encoded) if lowercase_urlencoding else encoded

    @staticmethod
    def b64encode(data):
        """
        Base64 encodes the given data.

        The input is first converted with compat.to_bytes and the encoded
        result is converted back with compat.to_string.

        :param data: The data to encode

        :return: The base64 encoded data
        """
        raw = compat.to_bytes(data)
        encoded = base64.b64encode(raw)
        return compat.to_string(encoded)

    @staticmethod
    def b64decode(data):
        """base64 decode"""
        return base64.b64decode(data)

    @staticmethod
    def decode_base64_and_inflate(value, ignore_zip=False):
        """
        base64 decodes and then inflates according to RFC1951
        :param value: a deflated and encoded string
        :type value: string
        :param ignore_zip: ignore zip errors
        :returns: the string after decoding and inflating
        :rtype: string
        """
        encoded = OneLogin_Saml2_Utils.b64decode(value)
        try:
            return zlib.decompress(encoded, -15)
        except zlib.error:
            if not ignore_zip:
                raise
        return encoded

    @staticmethod
    def deflate_and_base64_encode(value):
        """
        Deflates a string and then base64 encodes the result.

        :param value: The string to deflate and encode
        :type value: string
        :returns: The deflated and encoded string
        :rtype: string
        """
        compressed = zlib.compress(compat.to_bytes(value))
        # Strip the first two and the last four bytes that zlib.compress
        # wraps around the deflate stream, keeping only the stream itself.
        deflate_stream = compressed[2:-4]
        return OneLogin_Saml2_Utils.b64encode(deflate_stream)

    @staticmethod
    def format_cert(cert, heads=True):
        """
        Returns a x509 cert (adding header & footer if required).

        :param cert: A x509 unformatted cert
        :type: string

        :param heads: True if we want to include head and footer
        :type: boolean

        :returns: Formatted cert
        :rtype: string
        """
        x509_cert = cert.replace('\x0D', '')
        x509_cert = x509_cert.replace('\r', '')
        x509_cert = x509_cert.replace('\n', '')
        if len(x509_cert) > 0:
            x509_cert = x509_cert.replace('-----BEGIN CERTIFICATE-----', '')
            x509_cert = x509_cert.replace('-----END CERTIFICATE-----', '')
            x509_cert = x509_cert.replace(' ', '')

            if heads:
                x509_cert = "-----BEGIN CERTIFICATE-----\n" + "\n".join(wrap(x509_cert, 64)) + "\n-----END CERTIFICATE-----\n"

        return x509_cert

    @staticmethod
    def format_private_key(key, heads=True):
        """
        Returns a private key (adding header & footer if required).

        :param key A private key
        :type: string

        :param heads: True if we want to include head and footer
        :type: boolean

        :returns: Formated private key
        :rtype: string
        """
        private_key = key.replace('\x0D', '')
        private_key = private_key.replace('\r', '')
        private_key = private_key.replace('\n', '')
        if len(private_key) > 0:
            if private_key.find('-----BEGIN PRIVATE KEY-----') != -1:
                private_key = private_key.replace('-----BEGIN PRIVATE KEY-----', '')
                private_key = private_key.replace('-----END PRIVATE KEY-----', '')
                private_key = private_key.replace(' ', '')
                if heads:
                    private_key = "-----BEGIN PRIVATE KEY-----\n" + "\n".join(wrap(private_key, 64)) + "\n-----END PRIVATE KEY-----\n"
            else:
                private_key = private_key.replace('-----BEGIN RSA PRIVATE KEY-----', '')
                private_key = private_key.replace('-----END RSA PRIVATE KEY-----', '')
                private_key = private_key.replace(' ', '')
                if heads:
                    private_key = "-----BEGIN RSA PRIVATE KEY-----\n" + "\n".join(wrap(private_key, 64)) + "\n-----END RSA PRIVATE KEY-----\n"
        return private_key

    @staticmethod
    def redirect(url, parameters=None, request_data=None):
        """
        Builds the target url of a redirection (the url is returned, no
        redirection is actually performed here).

        A url starting with '/' is prefixed with the host url derived from
        request_data. Each parameter is appended to the query string:
        a None value adds the bare name, a list adds one "name[]=value" pair
        per item, and any other value adds "name=value".

        :param url: The target url
        :type: string

        :param parameters: Extra parameters to be passed as part of the url.
                           Defaults to no parameters.
        :type: dict

        :param request_data: The request as a dict. Defaults to an empty dict.
        :type: dict

        :returns: Url
        :rtype: string

        :raises OneLogin_Saml2_Error: if the resulting url is not http(s)
        """
        # None sentinels replace the former mutable ``{}`` defaults so that no
        # dict instance is shared between calls.
        if parameters is None:
            parameters = {}
        if request_data is None:
            request_data = {}

        assert isinstance(url, compat.str_type)
        assert isinstance(parameters, dict)

        if url.startswith('/'):
            url = '%s%s' % (OneLogin_Saml2_Utils.get_self_url_host(request_data), url)

        # Verify that the URL is to a http or https site.
        if re.search('^https?://', url) is None:
            raise OneLogin_Saml2_Error(
                'Redirect to invalid URL: ' + url,
                OneLogin_Saml2_Error.REDIRECT_INVALID_URL
            )

        # Add encoded parameters
        if url.find('?') < 0:
            param_prefix = '?'
        else:
            param_prefix = '&'

        escape = OneLogin_Saml2_Utils.escape_url
        for name, value in parameters.items():
            if value is None:
                param = escape(name)
            elif isinstance(value, list):
                # An empty list yields an empty string and is skipped below.
                param = '&'.join(escape(name) + '[]=' + escape(val) for val in value)
            else:
                param = escape(name) + '=' + escape(value)

            if param:
                url += param_prefix + param
                param_prefix = '&'

        return url

    @staticmethod
    def get_self_url_host(request_data):
        """
        Returns the protocol + the current host + the port (if different than
        common ports).

        :param request_data: The request as a dict
        :type: dict

        :return: Url
        :rtype: string
        """
        current_host = OneLogin_Saml2_Utils.get_self_host(request_data)
        port = ''
        if OneLogin_Saml2_Utils.is_https(request_data):
            protocol = 'https'
        else:
            protocol = 'http'

        if 'server_port' in request_data and request_data['server_port'] is not None:
            port_number = str(request_data['server_port'])
            port = ':' + port_number

            if protocol == 'http' and port_number == '80':
                port = ''
            elif protocol == 'https' and port_number == '443':
                port = ''

        return '%s://%s%s' % (protocol, current_host, port)

    @staticmethod
    def get_self_host(request_data):
        """
        Returns the current host.

        :param request_data: The request as a dict
        :type: dict

        :return: The current host
        :rtype: string
        """
        if 'http_host' in request_data:
            current_host = request_data['http_host']
        elif 'server_name' in request_data:
            current_host = request_data['server_name']
        else:
            raise Exception('No hostname defined')

        if ':' in current_host:
            current_host_data = current_host.split(':')
            possible_port = current_host_data[-1]
            try:
                int(possible_port)
                current_host = current_host_data[0]
            except ValueError:
                current_host = ':'.join(current_host_data)

        return current_host

    @staticmethod
    def is_https(request_data):
        """
        Checks if https or http.

        :param request_data: The request as a dict
        :type: dict

        :return: False if https is not active
        :rtype: boolean
        """
        is_https = 'https' in request_data and request_data['https'] != 'off'
        is_https = is_https or ('server_port' in request_data and str(request_data['server_port']) == '443')
        return is_https

    @staticmethod
    def get_self_url_no_query(request_data):
        """
        Returns the URL of the current host + current view.

        :param request_data: The request as a dict
        :type: dict

        :return: The url of current host + current view
        :rtype: string
        """
        self_url_host = OneLogin_Saml2_Utils.get_self_url_host(request_data)
        script_name = request_data['script_name']
        if script_name:
            if script_name[0] != '/':
                script_name = '/' + script_name
        else:
            script_name = ''
        self_url_no_query = self_url_host + script_name
        if 'path_info' in request_data:
            self_url_no_query += request_data['path_info']

        return self_url_no_query

    @staticmethod
    def get_self_routed_url_no_query(request_data):
        """
        Returns the routed URL of the current host + current view.

        :param request_data: The request as a dict
        :type: dict

        :return: The url of current host + current view
        :rtype: string
        """
        self_url_host = OneLogin_Saml2_Utils.get_self_url_host(request_data)
        route = ''
        if 'request_uri' in request_data and request_data['request_uri']:
            route = request_data['request_uri']
            if 'query_string' in request_data and request_data['query_string']:
                route = route.replace(request_data['query_string'], '')

        return self_url_host + route

    @staticmethod
    def get_self_url(request_data):
        """
        Returns the URL of the current host + current view + query.

        :param request_data: The request as a dict
        :type: dict

        :return: The url of current host + current view + query
        :rtype: string
        """
        self_url_host = OneLogin_Saml2_Utils.get_self_url_host(request_data)

        request_uri = ''
        if 'request_uri' in request_data:
            request_uri = request_data['request_uri']
            if not request_uri.startswith('/'):
                match = re.search('^https?://[^/]*(/.*)', request_uri)
                if match is not None:
                    request_uri = match.groups()[0]

        return self_url_host + request_uri

    @staticmethod
    def generate_unique_id():
        """
        Generates an unique string (used for example as ID for assertions):
        the prefix "ONELOGIN_" followed by the SHA-1 hex digest of a random
        UUID4.

        :return: A unique string
        :rtype: string
        """
        seed = compat.to_bytes(uuid4().hex)
        digest = sha1(seed).hexdigest()
        return 'ONELOGIN_%s' % digest

    @staticmethod
    def parse_time_to_SAML(time):
        r"""
        Converts a UNIX timestamp to SAML2 timestamp on the form
        yyyy-mm-ddThh:mm:ss(\.s+)?Z.

        :param time: The time we should convert (DateTime).
        :type: string

        :return: SAML2 timestamp.
        :rtype: string
        """
        data = datetime.utcfromtimestamp(float(time))
        return data.strftime(OneLogin_Saml2_Utils.TIME_FORMAT)

    @staticmethod
    def parse_SAML_to_time(timestr):
        r"""
        Converts a SAML2 timestamp on the form yyyy-mm-ddThh:mm:ss(\.s+)?Z
        to a UNIX timestamp. The sub-second part is ignored.

        :param timestr: The time we should convert (SAML Timestamp).
        :type: string

        :return: Converted to a unix timestamp.
        :rtype: int
        """
        try:
            data = datetime.strptime(timestr, OneLogin_Saml2_Utils.TIME_FORMAT)
        except ValueError:
            try:
                data = datetime.strptime(timestr, OneLogin_Saml2_Utils.TIME_FORMAT_2)
            except ValueError:
                elem = OneLogin_Saml2_Utils.TIME_FORMAT_WITH_FRAGMENT.match(timestr)
                if not elem:
                    raise Exception("time data %s does not match format %s" % (timestr, r'yyyy-mm-ddThh:mm:ss(\.s+)?Z'))
                data = datetime.strptime(elem.groups()[0] + "Z", OneLogin_Saml2_Utils.TIME_FORMAT)

        return calendar.timegm(data.utctimetuple())

    @staticmethod
    def now():
        """
        :return: unix timestamp of actual time.
        :rtype: int
        """
        return calendar.timegm(datetime.utcnow().utctimetuple())

    @staticmethod
    def parse_duration(duration, timestamp=None):
        """
        Applies an ISO8601 duration to a timestamp and returns the result.

        :param duration: The duration, as a string.
        :type: string

        :param timestamp: The unix timestamp we should apply the duration to.
                          Optional, default to the current time.
        :type: int

        :return: The new timestamp, after the duration is applied.
        :rtype: int
        """
        assert isinstance(duration, compat.str_type)
        assert timestamp is None or isinstance(timestamp, int)

        delta = duration_parser(duration)
        start = datetime.utcnow() if timestamp is None else datetime.utcfromtimestamp(timestamp)
        shifted = start + delta
        return calendar.timegm(shifted.utctimetuple())

    @staticmethod
    def get_expire_time(cache_duration=None, valid_until=None):
        """
        Compares 2 dates and returns the earliest.

        :param cache_duration: The duration, as a string.
        :type: string

        :param valid_until: The valid until date, as a string or as a timestamp
        :type: string

        :return: The expiration time.
        :rtype: int
        """
        expire_time = None

        if cache_duration is not None:
            expire_time = OneLogin_Saml2_Utils.parse_duration(cache_duration)

        if valid_until is not None:
            if isinstance(valid_until, int):
                valid_until_time = valid_until
            else:
                valid_until_time = OneLogin_Saml2_Utils.parse_SAML_to_time(valid_until)
            if expire_time is None or expire_time > valid_until_time:
                expire_time = valid_until_time

        if expire_time is not None:
            return '%d' % expire_time
        return None

    @staticmethod
    def delete_local_session(callback=None):
        """
        Deletes the local session.
        """

        if callback is not None:
            callback()

    @staticmethod
    def calculate_x509_fingerprint(x509_cert, alg='sha1'):
        """
        Calculates the fingerprint of a formatted x509cert: the hex digest of
        the base64-decoded data found between the certificate header and
        footer lines.

        :param x509_cert: x509 cert formatted
        :type: string

        :param alg: The algorithm to build the fingerprint ('sha1', 'sha256',
                    'sha384' or 'sha512'; anything else falls back to sha1)
        :type: string

        :returns: fingerprint in lower case hex, or None when no certificate
                  data is found
        :rtype: string
        """
        assert isinstance(x509_cert, compat.str_type)

        chunks = []
        collecting = False
        for raw_line in x509_cert.split('\n'):
            # Drop trailing whitespace, which includes a '\r' left by CRLF.
            stripped = raw_line.rstrip()
            if collecting:
                if stripped == '-----END CERTIFICATE-----':
                    break
                chunks.append(stripped)
            elif stripped == '-----BEGIN CERTIFICATE-----':
                collecting = True
            elif stripped in ('-----BEGIN PUBLIC KEY-----', '-----BEGIN RSA PRIVATE KEY-----'):
                # This isn't an X509 certificate.
                return None

        data = ''.join(chunks)
        if not data:
            return None

        decoded_data = base64.b64decode(compat.to_bytes(data))

        hashers = {'sha512': sha512, 'sha384': sha384, 'sha256': sha256}
        hasher = hashers.get(alg, sha1)
        return hasher(decoded_data).hexdigest().lower()

    @staticmethod
    def format_finger_print(fingerprint):
        """
        Formats a fingerprint.

        :param fingerprint: fingerprint
        :type: string

        :returns: Formatted fingerprint
        :rtype: string
        """
        formatted_fingerprint = fingerprint.replace(':', '')
        return formatted_fingerprint.lower()

    @staticmethod
    def generate_name_id(value, sp_nq, sp_format=None, cert=None, debug=False, nq=None):
        """
        Generates a nameID, encrypted when a certificate is supplied.

        A saml:NameID element is built inside a temporary container element.
        Without a cert its serialized form is returned; with a cert the
        element is encrypted with xmlsec and returned wrapped in a
        saml:EncryptedID tag.

        :param value: The text content of the NameID element
        :type: string

        :param sp_nq: SP Name Qualifier (attribute omitted when None)
        :type: string

        :param sp_format: SP Format (attribute omitted when None)
        :type: string

        :param cert: IdP Public Cert to encrypt the nameID
        :type: string

        :param debug: Activate the xmlsec debug
        :type: bool

        :param nq: IDP Name Qualifier (attribute omitted when None)
        :type: string

        :returns: DOMElement | XMLSec nameID
        :rtype: string
        """

        # The container only exists to give the NameID a parent element.
        root = OneLogin_Saml2_XML.make_root("{%s}container" % OneLogin_Saml2_Constants.NS_SAML)
        name_id = OneLogin_Saml2_XML.make_child(root, '{%s}NameID' % OneLogin_Saml2_Constants.NS_SAML)
        if sp_nq is not None:
            name_id.set('SPNameQualifier', sp_nq)
        if sp_format is not None:
            name_id.set('Format', sp_format)
        if nq is not None:
            name_id.set('NameQualifier', nq)
        name_id.text = value

        if cert is not None:
            xmlsec.enable_debug_trace(debug)

            # Load the public cert
            manager = xmlsec.KeysManager()
            manager.add_key(xmlsec.Key.from_memory(cert, xmlsec.KeyFormat.CERT_PEM, None))

            # Prepare for encryption: an AES128 EncryptedData template whose
            # KeyInfo carries an RSA-OAEP EncryptedKey.
            enc_data = xmlsec.template.encrypted_data_create(
                root, xmlsec.Transform.AES128, type=xmlsec.EncryptionType.ELEMENT, ns="xenc")

            xmlsec.template.encrypted_data_ensure_cipher_value(enc_data)
            key_info = xmlsec.template.encrypted_data_ensure_key_info(enc_data, ns="dsig")
            enc_key = xmlsec.template.add_encrypted_key(key_info, xmlsec.Transform.RSA_OAEP)
            xmlsec.template.encrypted_data_ensure_cipher_value(enc_key)

            # Encrypt! A fresh 128-bit AES session key is generated per call.
            enc_ctx = xmlsec.EncryptionContext(manager)
            enc_ctx.key = xmlsec.Key.generate(xmlsec.KeyData.AES, 128, xmlsec.KeyDataType.SESSION)
            enc_data = enc_ctx.encrypt_xml(enc_data, name_id)
            return '<saml:EncryptedID>' + compat.to_string(OneLogin_Saml2_XML.to_string(enc_data)) + '</saml:EncryptedID>'
        else:
            return OneLogin_Saml2_XML.extract_tag_text(root, "saml:NameID")

    @staticmethod
    def get_status(dom):
        """
        Gets Status from a Response.

        The 'code' entry is the first attribute value of the StatusCode
        element. The 'msg' entry is the StatusMessage text when exactly one
        is present; when there is none, it is the first attribute value of a
        single nested StatusCode; otherwise it is an empty string.

        :param dom: The Response as XML
        :type: Document

        :returns: The Status, a dict with the keys 'code' and 'msg'.
        :rtype: dict

        :raises OneLogin_Saml2_ValidationError: if the Response does not hold
            exactly one Status, or the Status exactly one StatusCode
        """
        status_nodes = OneLogin_Saml2_XML.query(dom, '/samlp:Response/samlp:Status')
        if len(status_nodes) != 1:
            raise OneLogin_Saml2_ValidationError(
                'Missing Status on response',
                OneLogin_Saml2_ValidationError.MISSING_STATUS
            )
        status_node = status_nodes[0]

        code_nodes = OneLogin_Saml2_XML.query(dom, '/samlp:Response/samlp:Status/samlp:StatusCode', status_node)
        if len(code_nodes) != 1:
            raise OneLogin_Saml2_ValidationError(
                'Missing Status Code on response',
                OneLogin_Saml2_ValidationError.MISSING_STATUS_CODE
            )
        status = {'code': code_nodes[0].values()[0], 'msg': ''}

        message_nodes = OneLogin_Saml2_XML.query(dom, '/samlp:Response/samlp:Status/samlp:StatusMessage', status_node)
        message_count = len(message_nodes)
        if message_count == 1:
            status['msg'] = OneLogin_Saml2_XML.element_text(message_nodes[0])
        elif message_count == 0:
            # No message: fall back to the nested (second-level) status code.
            subcode_nodes = OneLogin_Saml2_XML.query(dom, '/samlp:Response/samlp:Status/samlp:StatusCode/samlp:StatusCode', status_node)
            if len(subcode_nodes) == 1:
                status['msg'] = subcode_nodes[0].values()[0]

        return status

    @staticmethod
    def decrypt_element(encrypted_data, key, debug=False, inplace=False):
        """
        Decrypts an encrypted element.

        A minidom Element is first converted to an etree element. An etree
        element is then deep-copied unless inplace is True, and a text value
        is parsed into an etree element.

        :param encrypted_data: The encrypted data.
        :type: lxml.etree.Element | DOMElement | basestring

        :param key: The key.
        :type: string

        :param debug: Activate the xmlsec debug
        :type: bool

        :param inplace: update passed data with decrypted result
        :type: bool

        :returns: The decrypted element.
        :rtype: lxml.etree.Element
        """
        if isinstance(encrypted_data, Element):
            serialized = str(encrypted_data.toxml())
            encrypted_data = OneLogin_Saml2_XML.to_etree(serialized)

        copy_first = not inplace and isinstance(encrypted_data, OneLogin_Saml2_XML._element_class)
        if copy_first:
            # Work on a copy so the caller's tree is left untouched.
            encrypted_data = deepcopy(encrypted_data)
        elif isinstance(encrypted_data, OneLogin_Saml2_XML._text_class):
            encrypted_data = OneLogin_Saml2_XML._parse_etree(encrypted_data)

        xmlsec.enable_debug_trace(debug)

        private_key = xmlsec.Key.from_memory(key, xmlsec.KeyFormat.PEM, None)
        keys_manager = xmlsec.KeysManager()
        keys_manager.add_key(private_key)

        return xmlsec.EncryptionContext(keys_manager).decrypt(encrypted_data)

    @staticmethod
    def add_sign(xml, key, cert, debug=False, sign_algorithm=OneLogin_Saml2_Constants.RSA_SHA1, digest_algorithm=OneLogin_Saml2_Constants.SHA1):
        """
        Adds signature key and senders certificate to an element (Message or
        Assertion).

        The signature node is placed right after the first saml:Issuer when
        there is one (the Issuer's parent is then the signed element).
        Otherwise the root is signed, with the signature inserted as first
        child of the root for an EntityDescriptor, or as first child of the
        root's first child in any other case. A signed element without an
        ID attribute gets a generated one.

        :param xml: The element we should sign
        :type: string | Document

        :param key: The private key
        :type: string

        :param cert: The public
        :type: string

        :param debug: Activate the xmlsec debug
        :type: bool

        :param sign_algorithm: Signature algorithm method; unknown values
                               fall back to RSA-SHA1
        :type sign_algorithm: string

        :param digest_algorithm: Digest algorithm method; unknown values
                                 fall back to SHA1
        :type digest_algorithm: string

        :returns: Signed XML
        :rtype: string

        :raises Exception: if xml is None or an empty string
        """
        if xml is None or xml == '':
            raise Exception('Empty string supplied as input')

        elem = OneLogin_Saml2_XML.to_etree(xml)

        sign_algorithm_transform_map = {
            OneLogin_Saml2_Constants.DSA_SHA1: xmlsec.Transform.DSA_SHA1,
            OneLogin_Saml2_Constants.RSA_SHA1: xmlsec.Transform.RSA_SHA1,
            OneLogin_Saml2_Constants.RSA_SHA256: xmlsec.Transform.RSA_SHA256,
            OneLogin_Saml2_Constants.RSA_SHA384: xmlsec.Transform.RSA_SHA384,
            OneLogin_Saml2_Constants.RSA_SHA512: xmlsec.Transform.RSA_SHA512
        }
        sign_algorithm_transform = sign_algorithm_transform_map.get(sign_algorithm, xmlsec.Transform.RSA_SHA1)

        signature = xmlsec.template.create(elem, xmlsec.Transform.EXCL_C14N, sign_algorithm_transform, ns='ds')

        issuer = OneLogin_Saml2_XML.query(elem, '//saml:Issuer')
        if len(issuer) > 0:
            issuer = issuer[0]
            issuer.addnext(signature)
            elem_to_sign = issuer.getparent()
        else:
            entity_descriptor = OneLogin_Saml2_XML.query(elem, '//md:EntityDescriptor')
            if len(entity_descriptor) > 0:
                elem.insert(0, signature)
            else:
                elem[0].insert(0, signature)
            elem_to_sign = elem

        # Build the reference URI. An element without an ID attribute gets a
        # generated one; an empty ID is kept as an empty URI.
        elem_id = elem_to_sign.get('ID', None)
        if elem_id is None:
            generated_id = OneLogin_Saml2_Utils.generate_unique_id()
            elem_id = '#' + generated_id
            elem_to_sign.attrib['ID'] = generated_id
        elif elem_id:
            elem_id = '#' + elem_id

        xmlsec.enable_debug_trace(debug)
        xmlsec.tree.add_ids(elem_to_sign, ["ID"])

        digest_algorithm_transform_map = {
            OneLogin_Saml2_Constants.SHA1: xmlsec.Transform.SHA1,
            OneLogin_Saml2_Constants.SHA256: xmlsec.Transform.SHA256,
            OneLogin_Saml2_Constants.SHA384: xmlsec.Transform.SHA384,
            OneLogin_Saml2_Constants.SHA512: xmlsec.Transform.SHA512
        }
        digest_algorithm_transform = digest_algorithm_transform_map.get(digest_algorithm, xmlsec.Transform.SHA1)

        ref = xmlsec.template.add_reference(signature, digest_algorithm_transform, uri=elem_id)
        xmlsec.template.add_transform(ref, xmlsec.Transform.ENVELOPED)
        xmlsec.template.add_transform(ref, xmlsec.Transform.EXCL_C14N)
        key_info = xmlsec.template.ensure_key_info(signature)
        xmlsec.template.add_x509_data(key_info)

        dsig_ctx = xmlsec.SignatureContext()
        sign_key = xmlsec.Key.from_memory(key, xmlsec.KeyFormat.PEM, None)
        sign_key.load_cert_from_memory(cert, xmlsec.KeyFormat.PEM)

        dsig_ctx.key = sign_key
        dsig_ctx.sign(signature)

        return OneLogin_Saml2_XML.to_string(elem)

    @staticmethod
    @return_false_on_exception
    def validate_sign(xml, cert=None, fingerprint=None, fingerprintalg='sha1', validatecert=False, debug=False, xpath=None, multicerts=None):
        """
        Validates a signature (Message or Assertion).

        Without an explicit xpath the Response-level signature is looked up
        first and the Assertion-level one second. Exactly one signature node
        must be found.

        :param xml: The element we should validate
        :type: string | Document

        :param cert: The public cert
        :type: string

        :param fingerprint: The fingerprint of the public cert
        :type: string

        :param fingerprintalg: The algorithm used to build the fingerprint
        :type: string

        :param validatecert: If true, will verify the signature and if the cert is valid.
        :type: bool

        :param debug: Activate the xmlsec debug
        :type: bool

        :param xpath: The xpath of the signed element
        :type: string

        :param multicerts: Multiple public certs; when given, cert and
                           fingerprint are ignored and each of these certs is
                           tried in turn
        :type: list

        :param raise_exceptions: Whether to return false on failure or raise an exception
        :type raise_exceptions: Boolean
        """
        if xml is None or xml == '':
            raise Exception('Empty string supplied as input')

        elem = OneLogin_Saml2_XML.to_etree(xml)
        xmlsec.enable_debug_trace(debug)
        xmlsec.tree.add_ids(elem, ["ID"])

        if xpath:
            signature_nodes = OneLogin_Saml2_XML.query(elem, xpath)
        else:
            signature_nodes = OneLogin_Saml2_XML.query(elem, OneLogin_Saml2_Utils.RESPONSE_SIGNATURE_XPATH)
            if len(signature_nodes) == 0:
                signature_nodes = OneLogin_Saml2_XML.query(elem, OneLogin_Saml2_Utils.ASSERTION_SIGNATURE_XPATH)

        if len(signature_nodes) != 1:
            raise OneLogin_Saml2_ValidationError(
                'Expected exactly one signature node; got {}.'.format(len(signature_nodes)),
                OneLogin_Saml2_ValidationError.WRONG_NUMBER_OF_SIGNATURES
            )
        signature_node = signature_nodes[0]

        if not multicerts:
            return OneLogin_Saml2_Utils.validate_node_sign(signature_node, elem, cert, fingerprint, fingerprintalg, validatecert, debug, raise_exceptions=True)

        # With multiple certs, the cert and fingerprint arguments are ignored:
        # each candidate is tried quietly and the first match wins.
        for candidate_cert in multicerts:
            if OneLogin_Saml2_Utils.validate_node_sign(signature_node, elem, candidate_cert, None, None, validatecert, False, raise_exceptions=False):
                return True

        raise OneLogin_Saml2_ValidationError(
            'Signature validation failed. SAML Response rejected.',
            OneLogin_Saml2_ValidationError.INVALID_SIGNATURE
        )

    @staticmethod
    @return_false_on_exception
    def validate_metadata_sign(xml, cert=None, fingerprint=None, fingerprintalg='sha1', validatecert=False, debug=False):
        """
        Validates a signature of a EntityDescriptor.

        Signatures are searched at the EntitiesDescriptor level first, then
        at the EntityDescriptor level, and finally on the SPSSODescriptor and
        IDPSSODescriptor children. Every signature found at the first level
        that has any is validated.

        :param xml: The element we should validate
        :type: string | Document

        :param cert: The public cert
        :type: string

        :param fingerprint: The fingerprint of the public cert
        :type: string

        :param fingerprintalg: The algorithm used to build the fingerprint
        :type: string

        :param validatecert: If true, will verify the signature and if the cert is valid.
        :type: bool

        :param debug: Activate the xmlsec debug
        :type: bool

        :param raise_exceptions: Whether to return false on failure or raise an exception
        :type raise_exceptions: Boolean
        """
        if xml is None or xml == '':
            raise Exception('Empty string supplied as input')

        elem = OneLogin_Saml2_XML.to_etree(xml)
        xmlsec.enable_debug_trace(debug)
        xmlsec.tree.add_ids(elem, ["ID"])

        # Lookup levels, from outermost to innermost; the search stops at the
        # first level that yields at least one signature.
        xpath_levels = (
            ('/md:EntitiesDescriptor/ds:Signature',),
            ('/md:EntityDescriptor/ds:Signature',),
            ('/md:EntityDescriptor/md:SPSSODescriptor/ds:Signature',
             '/md:EntityDescriptor/md:IDPSSODescriptor/ds:Signature'),
        )
        signature_nodes = []
        for level in xpath_levels:
            for signature_xpath in level:
                signature_nodes += OneLogin_Saml2_XML.query(elem, signature_xpath)
            if len(signature_nodes) > 0:
                break

        if len(signature_nodes) == 0:
            raise Exception('Could not validate metadata signature: No signature nodes found.')

        for signature_node in signature_nodes:
            # Raises exception if invalid
            OneLogin_Saml2_Utils.validate_node_sign(signature_node, elem, cert, fingerprint, fingerprintalg, validatecert, debug, raise_exceptions=True)
        return True

    @staticmethod
    @return_false_on_exception
    def validate_node_sign(signature_node, elem, cert=None, fingerprint=None, fingerprintalg='sha1', validatecert=False, debug=False):
        """
        Validates a single signature node.

        When no cert is supplied but a fingerprint is, the first
        X509Certificate matched by the XPath query is formatted and
        fingerprinted; it is used as the cert only if its fingerprint equals
        the one supplied.

        :param signature_node: The signature node
        :type: Node

        :param elem: The element the signature belongs to. It is not read by
                     the current implementation.
        :type: Document

        :param cert: The public cert
        :type: string

        :param fingerprint: The fingerprint of the public cert
        :type: string

        :param fingerprintalg: The algorithm used to build the fingerprint
        :type: string

        :param validatecert: If true, will verify the signature and if the cert is valid.
        :type: bool

        :param debug: Activate the xmlsec debug. It is accepted for signature
                      compatibility but not read by the current implementation.
        :type: bool

        :param raise_exceptions: Whether to return false on failure or raise an exception.
                                 Consumed by the return_false_on_exception decorator.
        :type raise_exceptions: Boolean

        :returns: True when the signature was verified
        :rtype: bool

        :raises OneLogin_Saml2_Error: If no certificate could be determined
        :raises OneLogin_Saml2_ValidationError: If the signature verification fails
        """
        cert_is_missing = cert is None or cert == ''
        if cert_is_missing and fingerprint:
            # NOTE(review): the XPath begins with '//', which in XPath addresses
            # the whole document rather than only signature_node -- confirm that
            # taking the first certificate of the document is intended.
            embedded_certs = OneLogin_Saml2_XML.query(signature_node, '//ds:Signature/ds:KeyInfo/ds:X509Data/ds:X509Certificate')
            if len(embedded_certs) > 0:
                raw_cert = OneLogin_Saml2_XML.element_text(embedded_certs[0])
                candidate_cert = OneLogin_Saml2_Utils.format_cert(raw_cert)
                candidate_fingerprint = OneLogin_Saml2_Utils.calculate_x509_fingerprint(candidate_cert, fingerprintalg)
                # Only trust the embedded certificate when it matches the expected fingerprint
                if fingerprint == candidate_fingerprint:
                    cert = candidate_cert

        if cert is None or cert == '':
            raise OneLogin_Saml2_Error(
                'Could not validate node signature: No certificate provided.',
                OneLogin_Saml2_Error.CERT_NOT_FOUND
            )

        # Disabled handling of an empty Reference URI, kept for reference:
        # reference_elem = OneLogin_Saml2_XML.query(signature_node, '//ds:Reference')
        # if len(reference_elem) > 0:
        #     if reference_elem[0].get('URI') == '':
        #         reference_elem[0].set('URI', '#%s' % signature_node.getparent().get('ID'))

        if not validatecert:
            # Use the certificate directly as the verification key
            dsig_ctx = xmlsec.SignatureContext()
            dsig_ctx.key = xmlsec.Key.from_memory(cert, xmlsec.KeyFormat.CERT_PEM, None)
        else:
            # Load the certificate as trusted into a keys manager
            keys_manager = xmlsec.KeysManager()
            keys_manager.load_cert_from_memory(cert, xmlsec.KeyFormat.CERT_PEM, xmlsec.KeyDataType.TRUSTED)
            dsig_ctx = xmlsec.SignatureContext(keys_manager)

        dsig_ctx.set_enabled_key_data([xmlsec.KeyData.X509])

        try:
            dsig_ctx.verify(signature_node)
        except Exception as verify_error:
            raise OneLogin_Saml2_ValidationError(
                'Signature validation failed. SAML Response rejected. %s',
                OneLogin_Saml2_ValidationError.INVALID_SIGNATURE,
                str(verify_error)
            )

        return True

    @staticmethod
    def sign_binary(msg, key, algorithm=xmlsec.Transform.RSA_SHA1, debug=False):
        """
        Signs a binary message with a private key.

        :param msg: The message we should sign. A str is encoded as UTF-8 first.
        :type: bytes

        :param key: The private key (loaded as PEM)
        :type: string

        :param algorithm: The xmlsec transform used to sign. RSA-SHA1 by default.
        :type: xmlsec transform

        :param debug: Activate the xmlsec debug
        :type: bool

        :return: The signature produced by xmlsec's SignatureContext.sign_binary
        :rtype: presumably bytes, the original docstring said str -- NOTE(review): confirm
        """
        payload = msg.encode('utf8') if isinstance(msg, str) else msg

        xmlsec.enable_debug_trace(debug)
        signing_ctx = xmlsec.SignatureContext()
        signing_ctx.key = xmlsec.Key.from_memory(key, xmlsec.KeyFormat.PEM, None)
        return signing_ctx.sign_binary(compat.to_bytes(payload), algorithm)

    @staticmethod
    def validate_binary_sign(signed_query, signature, cert=None, algorithm=OneLogin_Saml2_Constants.RSA_SHA1, debug=False):
        """
        Validates signed binary data (Used to validate GET Signature).

        :param signed_query: The data that was signed
        :type: string

        :param signature: The signature that will be validated
        :type: string

        :param cert: The public cert
        :type: string

        :param algorithm: Signature algorithm. An algorithm that is not in the
                          supported map falls back to RSA-SHA1.
        :type: string

        :param debug: Activate the xmlsec debug. When set, an xmlsec error is
                      also printed before returning False.
        :type: bool

        :returns: True if the signature is valid, False if xmlsec raised an error
        :rtype: bool
        """
        try:
            xmlsec.enable_debug_trace(debug)
            verify_ctx = xmlsec.SignatureContext()
            verify_ctx.key = xmlsec.Key.from_memory(cert, xmlsec.KeyFormat.CERT_PEM, None)

            # Maps the SAML signature algorithm identifiers to xmlsec transforms
            transform_by_algorithm = {
                OneLogin_Saml2_Constants.DSA_SHA1: xmlsec.Transform.DSA_SHA1,
                OneLogin_Saml2_Constants.RSA_SHA1: xmlsec.Transform.RSA_SHA1,
                OneLogin_Saml2_Constants.RSA_SHA256: xmlsec.Transform.RSA_SHA256,
                OneLogin_Saml2_Constants.RSA_SHA384: xmlsec.Transform.RSA_SHA384,
                OneLogin_Saml2_Constants.RSA_SHA512: xmlsec.Transform.RSA_SHA512
            }
            transform = transform_by_algorithm.get(algorithm, xmlsec.Transform.RSA_SHA1)

            query_bytes = compat.to_bytes(signed_query)
            signature_bytes = compat.to_bytes(signature)
            verify_ctx.verify_binary(query_bytes, transform, signature_bytes)
        except xmlsec.Error as xmlsec_error:
            if debug:
                print(xmlsec_error)
            return False

        return True