# Copyright 2017 rpaas authors. All rights reserved. # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. import datetime import time import unittest import redis import consul from freezegun import freeze_time from mock import patch, call from rpaas import storage, tasks from rpaas import session_resumption, consul_manager from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID tasks.app.conf.CELERY_ALWAYS_EAGER = True class LoadBalancerFake(object): def __init__(self, name): self.name = name self.hosts = [] class HostFake(object): def __init__(self, id, group, dns_name): self.id = id self.group = group self.dns_name = dns_name self.fail_property = None def set_fail(self, name): self.fail_property = name def unset_fail(self, name): self.fail_property = None def __getattribute__(self, name): fail_property = object.__getattribute__(self, "fail_property") if fail_property and fail_property == name: raise AttributeError("{} not defined".format(name)) return object.__getattribute__(self, name) @freeze_time("2016-02-03 12:00:00") class SessionResumptionTestCase(unittest.TestCase): @classmethod def setUpClass(cls): cls.ca_key, cls.ca_cert = cls.generate_ca() @classmethod def generate_ca(cls): key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() ) subject = issuer = x509.Name([ x509.NameAttribute(NameOID.COUNTRY_NAME, u"BR"), x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"RJ"), x509.NameAttribute(NameOID.LOCALITY_NAME, u"Rio de Janeiro"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Tsuru Inc"), x509.NameAttribute(NameOID.COMMON_NAME, u"tsuru.io"), ]) cert = x509.CertificateBuilder().subject_name( subject ).issuer_name( issuer ).public_key( key.public_key() ).serial_number( x509.random_serial_number() ).not_valid_before( datetime.datetime.utcnow() ).not_valid_after( datetime.datetime.utcnow() + datetime.timedelta(days=10) ).add_extension( x509.SubjectAlternativeName([x509.DNSName(u"tsuru.io")]), critical=False, ).sign(key, hashes.SHA256(), default_backend()) key = key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), ) cert = cert.public_bytes(serialization.Encoding.PEM) return key, cert def setUp(self): self.master_token = "rpaas-test" self.config = { "CONSUL_HOST": "127.0.0.1", "CONSUL_TOKEN": self.master_token, "MONGO_DATABASE": "session_resumption_test", "RPAAS_SERVICE_NAME": "test_rpaas_session_resumption", "HOST_MANAGER": "fake", "SESSION_RESUMPTION_RUN_INTERVAL": 2, u"CA_CERT": unicode(self.ca_cert), u"CA_KEY": unicode(self.ca_key) } self.consul = consul.Consul(token=self.master_token) self.consul.kv.delete("test_rpaas_session_resumption", recurse=True) self.storage = storage.MongoDBStorage(self.config) self.consul_manager = consul_manager.ConsulManager(self.config) colls = self.storage.db.collection_names(False) for coll in colls: self.storage.db.drop_collection(coll) redis.StrictRedis().flushall() @patch("rpaas.tasks.sslutils.generate_session_ticket") @patch("rpaas.tasks.LoadBalancer") @patch("rpaas.tasks.nginx") def test_renew_session_tickets(self, nginx, load_balancer, ticket): nginx_manager = nginx.Nginx.return_value lb1 = LoadBalancerFake("instance-a") lb2 = LoadBalancerFake("instance-b") lb1.hosts = [HostFake("xxx", "instance-a", "10.1.1.1"), HostFake("yyy", "instance-a", "10.1.1.2")] lb2.hosts = [HostFake("aaa", "instance-b", "10.2.2.2"), HostFake("bbb", "instance-b", "10.2.2.3")] load_balancer.list.return_value = [lb1, lb2] ticket.side_effect = ["ticket1", "ticket2", "ticket3", "ticket4"] session = session_resumption.SessionResumption(self.config) session.start() time.sleep(1) session.stop() nginx_expected_calls = [call('10.1.1.1', 'ticket1', 30), call('10.1.1.2', 'ticket1', 30), call('10.2.2.2', 'ticket2', 30), call('10.2.2.3', 'ticket2', 30)] self.assertEqual(nginx_expected_calls, nginx_manager.add_session_ticket.call_args_list) cert_a, key_a = self.consul_manager.get_certificate("instance-a", "xxx") cert_b, key_b = self.consul_manager.get_certificate("instance-b", "bbb") redis.StrictRedis().delete("session_resumption:test_rpaas_session_resumption:last_run") nginx_manager.reset_mock() session = session_resumption.SessionResumption(self.config) session.start() time.sleep(1) session.stop() nginx_expected_calls = [call('10.1.1.1', 'ticket3', 30), call('10.1.1.2', 'ticket3', 30), call('10.2.2.2', 'ticket4', 30), call('10.2.2.3', 'ticket4', 30)] self.assertEqual(nginx_expected_calls, nginx_manager.add_session_ticket.call_args_list) self.assertTupleEqual((cert_a, key_a), self.consul_manager.get_certificate("instance-a", "xxx")) self.assertTupleEqual((cert_b, key_b), self.consul_manager.get_certificate("instance-b", "bbb")) @patch.object(tasks.SessionResumptionTask, "rotate_session_ticket", return_value=None) @patch("rpaas.tasks.LoadBalancer") def test_renew_session_tickets_only_on_selected_instances(self, load_balancer, rotate_session): self.config["SESSION_RESUMPTION_INSTANCES"] = "instance-a,instance-c" lb1 = LoadBalancerFake("instance-a") lb2 = LoadBalancerFake("instance-b") lb3 = LoadBalancerFake("instance-c") lb4 = LoadBalancerFake("instance-d") lb1.hosts = [HostFake("xxx", "instance-a", "10.1.1.1")] lb2.hosts = [HostFake("yyy", "instance-b", "10.2.1.1")] lb3.hosts = [HostFake("aaa", "instance-c", "10.3.2.2")] lb4.hosts = [HostFake("bbb", "instance-d", "10.4.2.2")] load_balancer.list.return_value = [lb1, lb2, lb3, lb4] session = session_resumption.SessionResumption(self.config) session.start() time.sleep(1) session.stop() self.assertEqual(rotate_session.call_args_list, [call(lb1.hosts), call(lb3.hosts)]) @patch("rpaas.tasks.logging") @patch("rpaas.tasks.sslutils.generate_session_ticket") @patch("rpaas.tasks.LoadBalancer") @patch("rpaas.tasks.nginx") def test_renew_session_tickets_fail_and_unlock(self, nginx, load_balancer, ticket, logging): nginx_manager = nginx.Nginx.return_value lb1_host2 = HostFake("yyy", "instance-a", "10.1.1.2") lb1_host2.set_fail("dns_name") lb1 = LoadBalancerFake("instance-a") lb2 = LoadBalancerFake("instance-b") lb1.hosts = [HostFake("xxx", "instance-a", "10.1.1.1"), lb1_host2] lb2.hosts = [HostFake("aaa", "instance-b", "10.2.2.2"), HostFake("bbb", "instance-b", "10.2.2.3")] load_balancer.list.return_value = [lb1, lb2] ticket.side_effect = ["ticket1", "ticket2", "ticket3", "ticket4"] session = session_resumption.SessionResumption(self.config) session.start() time.sleep(1) session.stop() nginx_expected_calls = [call('10.1.1.1', 'ticket1', 30), call('10.2.2.2', 'ticket2', 30), call('10.2.2.3', 'ticket2', 30)] self.assertEqual(nginx_expected_calls, nginx_manager.add_session_ticket.call_args_list) redis.StrictRedis().delete("session_resumption:test_rpaas_session_resumption:last_run") lb1_host2.unset_fail("dns_name") nginx_manager.reset_mock() session = session_resumption.SessionResumption(self.config) session.start() time.sleep(1) session.stop() nginx_expected_calls = [call('10.1.1.1', 'ticket3', 30), call('10.1.1.2', 'ticket3', 30), call('10.2.2.2', 'ticket4', 30), call('10.2.2.3', 'ticket4', 30)] self.assertEqual(nginx_expected_calls, nginx_manager.add_session_ticket.call_args_list) error_msg = "Error renewing session ticket for instance-a: AttributeError('dns_name not defined',)" logging.error.assert_called_with(error_msg) @patch("rpaas.tasks.logging") @patch("rpaas.tasks.sslutils.generate_admin_crt") @patch("rpaas.tasks.LoadBalancer") @patch("rpaas.tasks.nginx") def test_renew_session_tickets_return_first_error(self, nginx, load_balancer, generate_cert, logging): nginx_manager = nginx.Nginx.return_value lb1 = LoadBalancerFake("instance-a") lb1.hosts = [HostFake("xxx", "instance-a", "10.1.1.1")] load_balancer.list.return_value = [lb1] generate_cert.side_effect = Exception("could not generate certificate") nginx_manager.add_session_ticket.side_effect = Exception("nginx error connecting to host") session = session_resumption.SessionResumption(self.config) session.start() time.sleep(1) session.stop() error_msg = "Error renewing session ticket for instance-a: " \ "Exception('could not generate certificate',)" logging.error.assert_called_with(error_msg) nginx_manager.add_session_ticket.assert_not_called()