# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""
Test cases for the mock key manager.
"""

from cryptography.hazmat import backends
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from oslo_context import context

from castellan.common import exception
from castellan.common.objects import symmetric_key as sym_key
from castellan.tests.unit.key_manager import mock_key_manager as mock_key_mgr
from castellan.tests.unit.key_manager import test_key_manager as test_key_mgr


def get_cryptography_private_key(private_key):
    """Load a key manager private key as a ``cryptography`` key object.

    :param private_key: key object whose ``get_encoded()`` returns the
        DER-encoded private key.
    :returns: the result of ``serialization.load_der_private_key`` on those
        bytes, loaded without a password.
    """
    crypto_private_key = serialization.load_der_private_key(
        bytes(private_key.get_encoded()),
        password=None,
        backend=backends.default_backend())
    return crypto_private_key


def get_cryptography_public_key(public_key):
    """Load a key manager public key as a ``cryptography`` key object.

    :param public_key: key object whose ``get_encoded()`` returns the
        DER-encoded public key.
    :returns: the result of ``serialization.load_der_public_key`` on those
        bytes.
    """
    crypto_public_key = serialization.load_der_public_key(
        bytes(public_key.get_encoded()),
        backend=backends.default_backend())
    return crypto_public_key


class MockKeyManagerTestCase(test_key_mgr.KeyManagerTestCase):
    """Key manager tests exercised against ``MockKeyManager``."""

    def _create_key_manager(self):
        return mock_key_mgr.MockKeyManager()

    def setUp(self):
        super(MockKeyManagerTestCase, self).setUp()

        self.context = context.RequestContext('fake', 'fake')

    def cleanUp(self):
        # NOTE(review): unittest itself calls tearDown()/doCleanups(), not a
        # method named cleanUp(); confirm the base class invokes this hook,
        # otherwise the key store reset below never runs.
        super(MockKeyManagerTestCase, self).cleanUp()

        self.key_mgr.keys = {}

    @staticmethod
    def _make_aes_key(num_bytes):
        """Return an AES SymmetricKey made of ``num_bytes`` b'0' bytes.

        The declared bit length is ``num_bytes * 8`` so it matches the
        encoded key material.
        """
        return sym_key.SymmetricKey('AES', num_bytes * 8, b'0' * num_bytes)

    def test_create_key(self):
        key_id_1 = self.key_mgr.create_key(self.context)
        key_id_2 = self.key_mgr.create_key(self.context)
        # ensure that the UUIDs are unique
        self.assertNotEqual(key_id_1, key_id_2)

    def test_create_key_with_length(self):
        for length in [64, 128, 256]:
            key_id = self.key_mgr.create_key(self.context, length=length)
            key = self.key_mgr.get(self.context, key_id)
            # ``length`` is in bits while the encoded key is measured in
            # bytes; use floor division so the expected value is an int.
            self.assertEqual(length // 8, len(key.get_encoded()))
            self.assertIsNotNone(key.id)

    def test_create_key_with_name(self):
        name = 'my key'
        key_id = self.key_mgr.create_key(self.context, name=name)
        key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(name, key.name)
        self.assertIsNotNone(key.id)

    def test_create_key_with_algorithm(self):
        algorithm = 'DES'
        key_id = self.key_mgr.create_key(self.context, algorithm=algorithm)
        key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(algorithm, key.algorithm)
        self.assertIsNotNone(key.id)

    def test_create_key_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.create_key, None)

    def test_create_key_pair(self):
        for length in [2048, 3072, 4096]:
            name = str(length) + ' key'
            private_key_uuid, public_key_uuid = self.key_mgr.create_key_pair(
                self.context, 'RSA', length, name=name)

            private_key = self.key_mgr.get(self.context, private_key_uuid)
            self.assertIsNotNone(private_key.id)
            public_key = self.key_mgr.get(self.context, public_key_uuid)
            self.assertIsNotNone(public_key.id)

            crypto_private_key = get_cryptography_private_key(private_key)
            crypto_public_key = get_cryptography_public_key(public_key)

            self.assertEqual(name, private_key.name)
            self.assertEqual(name, public_key.name)
            self.assertEqual(length, crypto_private_key.key_size)
            self.assertEqual(length, crypto_public_key.key_size)

    def test_create_key_pair_encryption(self):
        private_key_uuid, public_key_uuid = self.key_mgr.create_key_pair(
            self.context, 'RSA', 2048)

        private_key = self.key_mgr.get(self.context, private_key_uuid)
        public_key = self.key_mgr.get(self.context, public_key_uuid)

        crypto_private_key = get_cryptography_private_key(private_key)
        crypto_public_key = get_cryptography_public_key(public_key)

        # Encryption and decryption must use identical OAEP parameters for
        # the round trip to succeed, so build the padding object once.
        oaep = padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA1()),
            algorithm=hashes.SHA1(),
            label=None)

        message = b'secret plaintext'
        ciphertext = crypto_public_key.encrypt(message, oaep)
        plaintext = crypto_private_key.decrypt(ciphertext, oaep)

        self.assertEqual(message, plaintext)

    def test_create_key_pair_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.create_key_pair, None, 'RSA', 2048)

    def test_create_key_pair_invalid_algorithm(self):
        self.assertRaises(ValueError,
                          self.key_mgr.create_key_pair,
                          self.context, 'DSA', 2048)

    def test_create_key_pair_invalid_length(self):
        self.assertRaises(ValueError,
                          self.key_mgr.create_key_pair,
                          self.context, 'RSA', 10)

    def test_store_and_get_key(self):
        key = self._make_aes_key(64)
        key_id = self.key_mgr.store(self.context, key)

        actual_key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(key, actual_key)
        self.assertIsNotNone(actual_key.id)

    def test_store_key_and_get_metadata(self):
        key = self._make_aes_key(64)
        key_id = self.key_mgr.store(self.context, key)

        actual_key = self.key_mgr.get(self.context, key_id,
                                      metadata_only=True)
        self.assertIsNone(actual_key.get_encoded())
        self.assertTrue(actual_key.is_metadata_only())
        self.assertIsNotNone(actual_key.id)

    def test_store_key_and_get_metadata_and_get_key(self):
        key = self._make_aes_key(64)
        key_id = self.key_mgr.store(self.context, key)

        actual_key = self.key_mgr.get(self.context, key_id,
                                      metadata_only=True)
        self.assertIsNone(actual_key.get_encoded())
        self.assertTrue(actual_key.is_metadata_only())

        # A metadata-only fetch must not prevent a later full fetch.
        actual_key = self.key_mgr.get(self.context, key_id,
                                      metadata_only=False)
        self.assertIsNotNone(actual_key.get_encoded())
        self.assertFalse(actual_key.is_metadata_only())
        self.assertIsNotNone(actual_key.id)

    def test_store_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.store, None, None)

    def test_get_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.get, None, None)

    def test_get_unknown_key(self):
        self.assertRaises(KeyError, self.key_mgr.get, self.context, None)

    def test_delete_key(self):
        key_id = self.key_mgr.create_key(self.context)
        self.key_mgr.delete(self.context, key_id)

        self.assertRaises(KeyError, self.key_mgr.get, self.context, key_id)

    def test_delete_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.delete, None, None)

    def test_delete_unknown_key(self):
        self.assertRaises(KeyError, self.key_mgr.delete, self.context, None)

    def test_list_null_context(self):
        self.assertRaises(exception.Forbidden, self.key_mgr.list, None)

    def test_list_keys(self):
        key1 = self._make_aes_key(64)
        self.key_mgr.store(self.context, key1)
        key2 = self._make_aes_key(32)
        self.key_mgr.store(self.context, key2)

        keys = self.key_mgr.list(self.context)
        self.assertEqual(2, len(keys))
        self.assertIn(key1, keys)
        self.assertIn(key2, keys)

        for key in keys:
            self.assertIsNotNone(key.id)

    def test_list_keys_metadata_only(self):
        key1 = self._make_aes_key(64)
        self.key_mgr.store(self.context, key1)
        key2 = self._make_aes_key(32)
        self.key_mgr.store(self.context, key2)

        keys = self.key_mgr.list(self.context, metadata_only=True)
        self.assertEqual(2, len(keys))

        # Metadata-only keys carry no key material, so identify them by
        # bit length instead of by equality with the stored keys.
        bit_lengths = [key1.bit_length, key2.bit_length]
        for key in keys:
            self.assertTrue(key.is_metadata_only())
            self.assertIn(key.bit_length, bit_lengths)
            self.assertIsNotNone(key.id)