# ===================================================================
#
# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the
#    distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# ===================================================================

import unittest

from Cryptodome.SelfTest.loader import load_tests
from Cryptodome.SelfTest.st_common import list_test_cases
from Cryptodome.Util.py3compat import tobytes, b, unhexlify
from Cryptodome.Cipher import AES, DES3, DES
from Cryptodome.Hash import SHAKE128

def get_tag_random(tag, length):
    """Return ``length`` bytes read from a SHAKE128 XOF seeded with ``tag``.

    The same tag always produces the same bytes, which gives the tests
    reproducible keys, IVs and data without hard-coding them.
    """
    xof = SHAKE128.new(data=tobytes(tag))
    return xof.read(length)

class BlockChainingTests(unittest.TestCase):
    """Generic tests for IV-based block chaining modes.

    Concrete subclasses must define ``aes_mode`` (passed to ``AES.new``)
    and ``des3_mode`` (passed to ``DES3.new``).
    """

    # Reproducible key/IV/data material derived from fixed tags.
    key_128 = get_tag_random("key_128", 16)
    key_192 = get_tag_random("key_192", 24)
    iv_128 = get_tag_random("iv_128", 16)
    iv_64 = get_tag_random("iv_64", 8)
    data_128 = get_tag_random("data_128", 16)

    def test_loopback_128(self):
        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        pt = get_tag_random("plaintext", 16 * 100)
        ct = cipher.encrypt(pt)

        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        pt2 = cipher.decrypt(ct)
        self.assertEqual(pt, pt2)

    def test_loopback_64(self):
        cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
        pt = get_tag_random("plaintext", 8 * 100)
        ct = cipher.encrypt(pt)

        cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
        pt2 = cipher.decrypt(ct)
        self.assertEqual(pt, pt2)

    def test_iv(self):
        # If not passed, the iv is created randomly
        cipher = AES.new(self.key_128, self.aes_mode)
        iv1 = cipher.iv
        cipher = AES.new(self.key_128, self.aes_mode)
        iv2 = cipher.iv
        self.assertNotEqual(iv1, iv2)
        self.assertEqual(len(iv1), 16)

        # The IV can be passed positionally, or with the keyword in
        # lowercase ("iv") or uppercase ("IV"); all give the same result.
        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        ct = cipher.encrypt(self.data_128)

        cipher = AES.new(self.key_128, self.aes_mode, iv=self.iv_128)
        self.assertEqual(ct, cipher.encrypt(self.data_128))

        cipher = AES.new(self.key_128, self.aes_mode, IV=self.iv_128)
        self.assertEqual(ct, cipher.encrypt(self.data_128))

    def test_iv_must_be_bytes(self):
        self.assertRaises(TypeError, AES.new, self.key_128, self.aes_mode,
                          iv = u'test1234567890-*')

    def test_only_one_iv(self):
        # Only one IV/iv keyword allowed
        self.assertRaises(TypeError, AES.new, self.key_128, self.aes_mode,
                          iv=self.iv_128, IV=self.iv_128)

    def test_iv_with_matching_length(self):
        # Empty, too short and too long IVs are all rejected.
        self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
                          b(""))
        self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
                          self.iv_128[:15])
        self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
                          self.iv_128 + b("0"))

    def test_block_size_128(self):
        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        self.assertEqual(cipher.block_size, AES.block_size)

    def test_block_size_64(self):
        cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
        self.assertEqual(cipher.block_size, DES3.block_size)

    def test_unaligned_data_128(self):
        # Any length that is not a multiple of the 16-byte block is rejected.
        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        for wrong_length in range(1, 16):
            self.assertRaises(ValueError, cipher.encrypt, b("5") * wrong_length)

        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        for wrong_length in range(1, 16):
            self.assertRaises(ValueError, cipher.decrypt, b("5") * wrong_length)

    def test_unaligned_data_64(self):
        # Any length that is not a multiple of the 8-byte block is rejected.
        cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
        for wrong_length in range(1, 8):
            self.assertRaises(ValueError, cipher.encrypt, b("5") * wrong_length)

        cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
        for wrong_length in range(1, 8):
            self.assertRaises(ValueError, cipher.decrypt, b("5") * wrong_length)

    def test_IV_iv_attributes(self):
        # Both attribute spellings report the original IV, even after use.
        data = get_tag_random("data", 16 * 100)
        for func in "encrypt", "decrypt":
            cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
            getattr(cipher, func)(data)
            self.assertEqual(cipher.iv, self.iv_128)
            self.assertEqual(cipher.IV, self.iv_128)

    def test_unknown_parameters(self):
        self.assertRaises(TypeError, AES.new, self.key_128, self.aes_mode,
                          self.iv_128, 7)
        self.assertRaises(TypeError, AES.new, self.key_128, self.aes_mode,
                          iv=self.iv_128, unknown=7)
        # But some are only known by the base cipher (e.g. use_aesni consumed by the AES module)
        AES.new(self.key_128, self.aes_mode, iv=self.iv_128, use_aesni=False)

    def test_null_encryption_decryption(self):
        for func in "encrypt", "decrypt":
            cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
            result = getattr(cipher, func)(b(""))
            self.assertEqual(result, b(""))

    def test_either_encrypt_or_decrypt(self):
        # Once used in one direction, a cipher object refuses the other.
        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        cipher.encrypt(b(""))
        self.assertRaises(TypeError, cipher.decrypt, b(""))

        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        cipher.decrypt(b(""))
        self.assertRaises(TypeError, cipher.encrypt, b(""))

    def test_data_must_be_bytes(self):
        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        self.assertRaises(TypeError, cipher.encrypt, u'test1234567890-*')

        cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
        self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')


class CbcTests(BlockChainingTests):
    """Run the generic chaining-mode tests against CBC mode."""

    des3_mode = DES3.MODE_CBC
    aes_mode = AES.MODE_CBC


class NistBlockChainingVectors(unittest.TestCase):
    """Check a chaining mode against NIST test vector response files.

    Concrete subclasses must define ``aes_mode``, ``des_mode`` and
    ``des3_mode``.
    """

    # The vector files contain directive lines (e.g. "[ENCRYPT]") that
    # load_tests returns as strings, interleaved with the parsed vectors.
    # ``basestring`` only exists on Python 2.
    try:
        _string_types = basestring
    except NameError:
        _string_types = str

    def _load_vectors(self, algo_dir, file_name, description):
        """Load the vectors from test_vectors/<algo_dir>/<file_name>.

        Fails the test if nothing was loaded.
        """
        test_vectors = load_tests(("Cryptodome", "SelfTest", "Cipher",
                                   "test_vectors", algo_dir),
                                  file_name,
                                  description,
                                  {"count": int})
        self.assertTrue(test_vectors)
        return test_vectors

    def _iter_directed(self, test_vectors):
        """Yield ``(direction, tv)`` for each vector in ``test_vectors``.

        ``direction`` is the last directive line seen before ``tv``
        (``None`` if none was seen yet). ``self.description`` is set to the
        vector's description before the pair is yielded.
        """
        direction = None
        for tv in test_vectors:
            if isinstance(tv, self._string_types):
                direction = tv
                continue
            self.description = tv.desc
            yield direction, tv

    def _unexpected_direction(self, direction, tv):
        """Fail the test: ``tv`` is not preceded by a known directive."""
        self.fail("Unexpected direction %r for test vector %r"
                  % (direction, tv.desc))

    def _check_known_answer(self, direction, cipher, tv):
        """Run one encrypt or decrypt known-answer check on ``tv``."""
        if direction == "[ENCRYPT]":
            self.assertEqual(cipher.encrypt(tv.plaintext), tv.ciphertext)
        elif direction == "[DECRYPT]":
            self.assertEqual(cipher.decrypt(tv.ciphertext), tv.plaintext)
        else:
            self._unexpected_direction(direction, tv)

    def _do_kat_aes_test(self, file_name):
        test_vectors = self._load_vectors("AES", file_name, "AES KAT")
        for direction, tv in self._iter_directed(test_vectors):
            cipher = AES.new(tv.key, self.aes_mode, tv.iv)
            self._check_known_answer(direction, cipher, tv)

    # See Section 6.4.2 in AESAVS
    def _do_mct_aes_test(self, file_name):
        test_vectors = self._load_vectors("AES", file_name, "AES Montecarlo")
        for direction, tv in self._iter_directed(test_vectors):
            cipher = AES.new(tv.key, self.aes_mode, tv.iv)

            # 1000 chained operations on one cipher object. The input of
            # step j is the output of step j-2 (the IV counts as the output
            # of step -1); only the final output is checked.
            if direction == '[ENCRYPT]':
                pt = tv.plaintext
                cts = [tv.iv]
                for _ in range(1000):
                    cts.append(cipher.encrypt(pt))
                    pt = cts[-2]
                self.assertEqual(cts[-1], tv.ciphertext)
            elif direction == '[DECRYPT]':
                ct = tv.ciphertext
                pts = [tv.iv]
                for _ in range(1000):
                    pts.append(cipher.decrypt(ct))
                    ct = pts[-2]
                self.assertEqual(pts[-1], tv.plaintext)
            else:
                self._unexpected_direction(direction, tv)

    def _do_tdes_test(self, file_name):
        test_vectors = self._load_vectors("TDES", file_name, "TDES CBC KAT")
        for direction, tv in self._iter_directed(test_vectors):
            if hasattr(tv, "keys"):
                # Single DES vectors carry one "keys" field.
                cipher = DES.new(tv.keys, self.des_mode, tv.iv)
            else:
                if tv.key1 != tv.key3:
                    key = tv.key1 + tv.key2 + tv.key3  # Option 3
                else:
                    key = tv.key1 + tv.key2            # Option 2
                cipher = DES3.new(key, self.des3_mode, tv.iv)
            self._check_known_answer(direction, cipher, tv)


class NistCbcVectors(NistBlockChainingVectors):
    """Run the NIST test vector files against the CBC mode of AES, DES
    and Triple DES."""

    des3_mode = DES3.MODE_CBC
    des_mode = DES.MODE_CBC
    aes_mode = AES.MODE_CBC


def _add_file_tests(test_class, prefix, file_names, runner_name):
    """Attach one test method per test vector file to ``test_class``.

    For every name in ``file_names`` a method called
    ``"test_" + prefix + "_" + file_name`` is added; it calls
    ``self.<runner_name>(file_name)``. Running the loop inside a function
    keeps the loop variables out of the module namespace.
    """
    for file_name in file_names:
        # Bind file_name as a default value so each method keeps its own
        # file; a plain closure would only see the last loop value.
        def new_func(self, file_name=file_name):
            getattr(self, runner_name)(file_name)
        setattr(test_class, "test_" + prefix + "_" + file_name, new_func)


# Create one test method per file
nist_aes_kat_mmt_files = (
    # KAT
    "CBCGFSbox128.rsp",
    "CBCGFSbox192.rsp",
    "CBCGFSbox256.rsp",
    "CBCKeySbox128.rsp",
    "CBCKeySbox192.rsp",
    "CBCKeySbox256.rsp",
    "CBCVarKey128.rsp",
    "CBCVarKey192.rsp",
    "CBCVarKey256.rsp",
    "CBCVarTxt128.rsp",
    "CBCVarTxt192.rsp",
    "CBCVarTxt256.rsp",
    # MMT
    "CBCMMT128.rsp",
    "CBCMMT192.rsp",
    "CBCMMT256.rsp",
    )
nist_aes_mct_files = (
    "CBCMCT128.rsp",
    "CBCMCT192.rsp",
    "CBCMCT256.rsp",
    )

_add_file_tests(NistCbcVectors, "AES", nist_aes_kat_mmt_files, "_do_kat_aes_test")
_add_file_tests(NistCbcVectors, "AES", nist_aes_mct_files, "_do_mct_aes_test")

nist_tdes_files = (
    "TCBCMMT2.rsp",    # 2TDES
    "TCBCMMT3.rsp",    # 3TDES
    "TCBCinvperm.rsp", # Single DES
    "TCBCpermop.rsp",
    "TCBCsubtab.rsp",
    "TCBCvarkey.rsp",
    "TCBCvartext.rsp",
    )

_add_file_tests(NistCbcVectors, "TDES", nist_tdes_files, "_do_tdes_test")

# END OF NIST CBC TEST VECTORS


class SP800TestVectors(unittest.TestCase):
    """Class exercising the CBC test vectors found in Section F.2
    of NIST SP 800-38A"""

    # IV and four-block plaintext (hex) shared by the three AES-CBC
    # examples checked below.
    _iv_hex = '000102030405060708090a0b0c0d0e0f'
    _plaintext_hex = ('6bc1bee22e409f96e93d7e117393172a'
                      'ae2d8a571e03ac9c9eb76fac45af8e51'
                      '30c81c46a35ce411e5fbc1191a0a52ef'
                      'f69f2445df4f9b17ad2b417be66c3710')

    def _check_cbc(self, key_hex, ciphertext_hex):
        """Check AES-CBC under ``key_hex`` maps the shared plaintext to
        ``ciphertext_hex`` and back (both arguments are hex strings)."""
        key = unhexlify(key_hex)
        iv = unhexlify(self._iv_hex)
        plaintext = unhexlify(self._plaintext_hex)
        ciphertext = unhexlify(ciphertext_hex)

        # A cipher object used for encryption cannot then decrypt, so a
        # fresh one is created for each direction.
        cipher = AES.new(key, AES.MODE_CBC, iv)
        self.assertEqual(cipher.encrypt(plaintext), ciphertext)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        self.assertEqual(cipher.decrypt(ciphertext), plaintext)

    def test_aes_128(self):
        self._check_cbc('2b7e151628aed2a6abf7158809cf4f3c',
                        '7649abac8119b246cee98e9b12e9197d'
                        '5086cb9b507219ee95db113a917678b2'
                        '73bed6b8e3c1743b7116e69e22229516'
                        '3ff1caa1681fac09120eca307586e1a7')

    def test_aes_192(self):
        self._check_cbc('8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b',
                        '4f021db243bc633d7178183a9fa071e8'
                        'b4d9ada9ad7dedf4e5e738763f69145a'
                        '571b242012fb7ae07fa9baac3df102e0'
                        '08b0e27988598881d920a9e64f5615cd')

    def test_aes_256(self):
        self._check_cbc('603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4',
                        'f58c4c04d6e5f1ba779eabfb5f7bfbd6'
                        '9cfc4e967edb808d679f777bc6702c7d'
                        '39f23369a9d9bacfa530e26304231461'
                        'b2eb05e2c39be9fcda6c19078c6a9d1b')


def get_tests(config=None):
    """Return the list of CBC test cases defined in this module.

    ``config`` is not used by this module; it defaults to ``None`` rather
    than a shared mutable dict.
    """
    tests = []
    tests += list_test_cases(CbcTests)
    tests += list_test_cases(NistCbcVectors)
    tests += list_test_cases(SP800TestVectors)
    return tests


if __name__ == '__main__':
    # unittest.main resolves defaultTest by name in this module, so
    # ``suite`` must be a module-level callable returning the suite.
    def suite():
        return unittest.TestSuite(get_tests())

    unittest.main(defaultTest='suite')