import json
import os.path
import platform
import socket
from base64 import b64decode, b64encode
from pathlib import Path

import pysodium

from .exceptions import ProtocolError

BUFF_SIZE = 1024 * 1024
DEFAULT_SOCKET_TIMEOUT = 60
DEFAULT_SOCKET_NAME = 'kpxc_server'


def create_keypair():
    """Return (public key, private key) of a fresh crypto_box keypair."""
    return pysodium.crypto_box_keypair()


def create_nonce():
    """Return a random nonce of ``crypto_box_NONCEBYTES`` bytes."""
    return pysodium.randombytes(pysodium.crypto_box_NONCEBYTES)


def create_public_key():
    """Return random bytes with the length of a crypto_box public key.

    Used as the default ``id_key`` of an :class:`Identity`.
    """
    return pysodium.randombytes(pysodium.crypto_box_PUBLICKEYBYTES)


def create_nonces(nonce=None, next_nonce=None):
    """Return ``(nonce, next_nonce)``, generating whichever is missing.

    If ``nonce`` is None a random one is created, and ``next_nonce`` must
    then also be None. If ``next_nonce`` is None it is derived as
    ``nonce + 1`` (see :func:`increment_nonce`); it is the nonce that
    replies are checked against.
    """
    if nonce is None:
        nonce = create_nonce()
        assert next_nonce is None, next_nonce
    if next_nonce is None:
        next_nonce = increment_nonce(nonce)
    return nonce, next_nonce


def increment_nonce(nonce):
    """Return ``nonce`` incremented by one, read as a little-endian integer.

    The first byte is the least significant. An overflowing byte
    (0xff + 1) wraps to 0 and carries into the next byte. An all-0xff
    nonce wraps around to all zeros. This is the same convention as
    libsodium's ``sodium_increment``.
    """
    assert isinstance(nonce, bytes)
    incremented = bytearray(nonce)
    carry = 1
    for index, byte in enumerate(nonce):
        total = byte + carry
        # Keep the low 8 bits in this byte and carry the overflow onward.
        # Reducing modulo 256 *before* shifting would always lose the carry.
        incremented[index] = total & 0xFF
        carry = total >> 8
    return bytes(incremented)


def encrypt(message, nonce, serverKey, secretKey):
    """Encrypt ``message`` for ``serverKey`` with libsodium ``crypto_box``."""
    return pysodium.crypto_box(message, nonce, serverKey, secretKey)


def decrypt(message, nonce, serverKey, secretKey):
    """Open (verify and decrypt) a ``crypto_box`` ciphertext from ``serverKey``."""
    return pysodium.crypto_box_open(message, nonce, serverKey, secretKey)


def binary_to_b64(binary):
    """Encode ``bytes`` as a base64 ``str``."""
    assert isinstance(binary, bytes), binary
    return b64encode(binary).decode()


def binary_from_b64(s):
    """Decode a base64 ``str`` into ``bytes``."""
    assert isinstance(s, str), s
    return b64decode(s.encode())


def check_nonces(response, expected_nonce):
    """Assert that dict ``response`` carries ``expected_nonce`` (base64) under 'nonce'."""
    assert isinstance(response, dict), response
    nonce_key = 'nonce'
    assert nonce_key in response, repr(response)
    response_nonce = binary_from_b64(response[nonce_key])
    assert response_nonce == expected_nonce


def create_command(action, **data):
    """Return a command dict for ``action`` with ``triggerUnlock`` set, plus ``data``."""
    command = {"action": action, "triggerUnlock": 'true'}
    command.update(data)
    return command


def create_message(action, **data):
    """Return the plaintext payload for ``action``; same shape as :func:`create_command`."""
    return create_command(action, **data)


def create_encrypted_command(crypto, action, message):
    """Build an ``action`` command whose 'message' field is ``message`` encrypted.

    ``crypto`` must provide ``encrypt_message(message, nonce)`` (e.g. an
    :class:`Identity`). A fresh nonce is used.

    Returns ``(command, nonce)``.
    """
    nonce = create_nonce()
    command = create_command(
        action,
        message=binary_to_b64(crypto.encrypt_message(message, nonce)),
    )
    return command, nonce


class Connection:
    """A client connection to the keepassxc Unix-domain socket."""

    def __init__(self):
        """Locate the keepassxc socket, raising OSError if none is found.

        The socket file ``DEFAULT_SOCKET_NAME`` is looked up in ``$TMPDIR``
        (tried first on macOS) and in ``$XDG_RUNTIME_DIR``. No connection
        is opened until :meth:`connect` is called.
        """
        tmpdir = os.getenv('TMPDIR')
        tmpdir_socket_path = Path(tmpdir) / DEFAULT_SOCKET_NAME if tmpdir else None
        xdg_runtime_dir = os.getenv('XDG_RUNTIME_DIR')
        runtime_socket_path = (
            Path(xdg_runtime_dir) / DEFAULT_SOCKET_NAME if xdg_runtime_dir else None
        )

        def exists(path):
            return path is not None and path.exists()

        # TODO: darwin is untested
        if platform.system() == "Darwin" and exists(tmpdir_socket_path):
            server_address = tmpdir_socket_path
        elif exists(runtime_socket_path):
            server_address = runtime_socket_path
        # TODO: tmpdir is untested
        elif exists(tmpdir_socket_path):
            server_address = tmpdir_socket_path
        else:
            raise OSError('Unknown path for keepassxc socket.')
        self.server_address = server_address
        self.sock = None

    def connect(self):
        """Open the socket to ``self.server_address``.

        Raises Exception (chained to the underlying OSError) if the
        connection cannot be made. The socket is closed on any failure.
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.settimeout(DEFAULT_SOCKET_TIMEOUT)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, BUFF_SIZE)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFF_SIZE)
            try:
                sock.connect(str(self.server_address))
            except socket.error as err:
                raise Exception(
                    "Could not connect to {addr}".format(addr=self.server_address)
                ) from err
        except BaseException:
            # Never leak the file descriptor, whatever failed above.
            sock.close()
            raise
        self.sock = sock

    def disconnect(self):
        """Close the socket if one is open; safe to call when not connected."""
        if self.sock is not None:
            self.sock.close()
            self.sock = None

    def send(self, command):
        """Send the string ``command`` and return the decoded reply.

        NOTE(review): the reply is read with a single ``recv`` of at most
        ``BUFF_SIZE`` bytes, so a reply split across reads would be truncated.
        """
        assert isinstance(command, str)
        # sendall() keeps writing until every byte is sent; send() may stop short.
        self.sock.sendall(command.encode())
        return self.sock.recv(BUFF_SIZE).decode()

    def send_json(self, command):
        """JSON-encode ``command``, send it, and return the JSON-decoded reply."""
        return json.loads(self.send(json.dumps(command)))

    def send_command(self, identity, command, nonce=None, next_nonce=None):
        """Sign ``command`` with ``identity``, send it, and return the reply dict.

        Raises ProtocolError if the reply contains an 'error' key. It also
        asserts that the reply's nonce equals ``next_nonce``.
        """
        nonce, next_nonce = create_nonces(nonce, next_nonce)
        identity.sign_command(command, nonce)
        resp = self.send_json(command)
        if 'error' in resp:
            raise ProtocolError(resp['error'])
        check_nonces(resp, next_nonce)
        return resp

    def send_encrypted_command(self, identity, command, nonce=None, next_nonce=None):
        """Send ``command`` and return its decrypted 'message' payload."""
        nonce, next_nonce = create_nonces(nonce, next_nonce)
        resp = self.send_command(identity, command, nonce, next_nonce)
        resp_message = identity.decrypt_message(resp['message'], next_nonce)
        return resp_message

    def encrypt_message_send_command(self, identity, action, message):
        """Encrypt ``message`` into an ``action`` command, send it, and return the decrypted reply."""
        command, nonce = create_encrypted_command(identity, action, message)
        return self.send_encrypted_command(identity, command, nonce)

    def change_public_keys(self, identity):
        """Send our public key and store the server's key on ``identity.serverPublicKey``."""
        command = create_command(
            'change-public-keys', publicKey=binary_to_b64(identity.publicKey)
        )
        resp = self.send_command(identity, command)
        assert 'publicKey' in resp, resp
        identity.serverPublicKey = binary_from_b64(resp['publicKey'])

    def get_database_hash(self, identity):
        """Return the 'hash' value from a 'get-databasehash' reply."""
        action = 'get-databasehash'
        message = create_message(action)
        resp_message = self.encrypt_message_send_command(identity, action, message)
        return resp_message['hash']

    def associate(self, identity):
        """Associate ``identity`` with the database and return the assigned name.

        The name is also stored on ``identity.associated_name``.
        """
        action = 'associate'
        message = create_message(
            action,
            key=binary_to_b64(identity.publicKey),
            idKey=binary_to_b64(identity.associated_id_key),
        )
        resp_message = self.encrypt_message_send_command(identity, action, message)
        assert 'id' in resp_message
        associated_name = resp_message['id']
        identity.associated_name = associated_name
        return associated_name

    def test_associate(self, identity):
        """Return True if 'test-associate' succeeds, False if the server reports an error."""
        action = 'test-associate'
        assert identity.associated_id_key is not None, identity.associated_id_key
        message = create_message(
            action,
            id=identity.associated_name,
            key=binary_to_b64(identity.associated_id_key),
        )
        try:
            self.encrypt_message_send_command(identity, action, message)
        except ProtocolError:
            return False
        return True

    def create_password(self, identity):
        """Request a generated password and return ``(login, password)`` of the single entry."""
        action = 'generate-password'
        command = create_command(action)
        nonce = create_nonce()
        resp = self.send_encrypted_command(identity, command, nonce)
        assert 'entries' in resp
        entries = resp['entries']
        assert len(entries) == 1, resp
        entry = entries[0]
        return entry['login'], entry['password']

    def get_logins(self, identity, url, submit_url=None, http_auth=None):
        """Return the 'entries' list from a 'get-logins' request for ``url``.

        ``submit_url`` and ``http_auth`` are sent only when truthy.
        """
        action = 'get-logins'
        message = create_message(
            action,
            id=identity.associated_name,
            url=url,
            keys=[
                dict(
                    id=identity.associated_name,
                    key=binary_to_b64(identity.associated_id_key),
                )
            ],
        )
        if submit_url:
            message['submitUrl'] = submit_url
        if http_auth:
            message['httpAuth'] = http_auth
        resp_message = self.encrypt_message_send_command(identity, action, message)
        return resp_message['entries']

    def set_login(
        self, identity, url, login=None, password=None, entry_id=None, submit_url=None
    ):
        """Store a login for ``url`` with a 'set-login' request.

        ``url`` must start with "mailto:" or "https:". Optional arguments
        that are None are omitted. ``entry_id`` is sent under the
        keepassxc-browser protocol key ``uuid``. ``submit_url`` is sent
        under ``submitUrl``, the same key :meth:`get_logins` uses.
        """
        if not url.startswith(('mailto:', 'https:')):
            raise Exception('Url needs to start with "mailto:" or "https:"')
        action = 'set-login'
        message = create_message(action, id=identity.associated_name, url=url)
        optional_fields = {
            'login': login,
            'password': password,
            'uuid': entry_id,
            'submitUrl': submit_url,
        }
        message.update(
            (key, value) for key, value in optional_fields.items() if value is not None
        )
        resp_message = self.encrypt_message_send_command(identity, action, message)
        assert resp_message['success']

    def lock_database(self, identity):
        """Send 'lock-database' and assert the reply reports success."""
        action = 'lock-database'
        message = create_message(action)
        resp_message = self.encrypt_message_send_command(identity, action, message)
        assert resp_message['success']

    def is_database_open(self, identity):
        """Return True if a database-hash request succeeds, False on ProtocolError."""
        # Yeah, that's really hacky, FIXME when https://github.com/keepassxreboot/keepassxc-browser/issues/594 is closed
        try:
            self.get_database_hash(identity)
            return True
        except ProtocolError:
            return False

    def wait_for_unlock(self):
        """
        This will listen to all messages until {'action': 'database-unlocked'} is received.
        If the database is already open, it will wait until it is unlocked the next time.
        This will not time out.
        If the database was unlocked while connected, and this method is called afterwards,
        it will return even if the database has been closed again in the meantime.

        Raises ConnectionError if the server closes the connection while waiting.
        """
        while True:
            try:
                data = self.sock.recv(BUFF_SIZE)
            except socket.timeout:
                # The socket has a finite timeout; keep waiting indefinitely.
                continue
            if not data:
                # recv() returns b'' once the peer has closed the connection;
                # parsing that as JSON would only give an opaque JSONDecodeError.
                raise ConnectionError('keepassxc closed the connection')
            if json.loads(data.decode()).get('action') == "database-unlocked":
                break


class Identity:
    """Client key material and association state.

    A fresh crypto_box keypair is generated on every construction. Only
    ``associated_id_key`` and ``associated_name`` are persisted by
    :meth:`serialize`.
    """

    VERSION = 1
    VERSION_KEY = 'version'
    BINARY_KEY = 'binary'
    TEXT_KEY = 'text'

    def __init__(
        self,
        client_id,
        id_key=None,
        associated_name=None,
    ):
        """Create an identity; a random ``id_key`` is generated when none is given."""
        self.client_id = client_id
        public_key, private_key = create_keypair()
        if not id_key:
            id_key = create_public_key()
        self.publicKey = public_key
        self.secretKey = private_key
        self.associated_id_key = id_key
        self.associated_name = associated_name
        # Set by Connection.change_public_keys before messages can be encrypted.
        self.serverPublicKey = None

    def sign_command(self, command, nonce):
        """Add 'nonce' (base64) and 'clientID' to ``command`` unless already present."""
        command.setdefault('nonce', binary_to_b64(nonce))
        command.setdefault('clientID', self.client_id)

    def encrypt_message(self, message, nonce):
        """JSON-encode ``message`` and encrypt it for the server's public key."""
        message = json.dumps(message).encode()
        assert self.serverPublicKey
        return encrypt(message, nonce, self.serverPublicKey, self.secretKey)

    def decrypt_message(self, resp_message, expected_nonce):
        """Decrypt a base64 reply payload, parse its JSON, and check its nonce."""
        ciphertext = binary_from_b64(resp_message)
        plaintext = decrypt(
            ciphertext, expected_nonce, self.serverPublicKey, self.secretKey
        )
        decoded = json.loads(plaintext)
        check_nonces(decoded, expected_nonce)
        return decoded

    def serialize(self):
        """Return a versioned JSON string holding the id key and associated name."""
        binary_data = [binary_to_b64(d) for d in (self.associated_id_key,)]
        text_data = [self.associated_name]
        return json.dumps({
            self.VERSION_KEY: self.VERSION,
            self.BINARY_KEY: binary_data,
            self.TEXT_KEY: text_data,
        })

    @classmethod
    def unserialize(cls, client_id, s):
        """Rebuild an Identity from :meth:`serialize` output (or the legacy list format)."""
        data = json.loads(s)
        if isinstance(data, list):
            return cls.unserialize_v0(client_id, data)
        unserializers = {
            1: cls.unserialize_v1,
        }
        version = data[cls.VERSION_KEY]
        assert version in unserializers, 'unknown version %s' % version
        return unserializers[version](client_id, data)

    @classmethod
    def unserialize_v1(cls, client_id, data):
        """Unserialize the version-1 dict format produced by :meth:`serialize`."""
        binary_data = [binary_from_b64(d) for d in data[cls.BINARY_KEY]]
        (id_key,) = binary_data
        (associated_name,) = data[cls.TEXT_KEY]
        return cls(
            client_id=client_id,
            id_key=id_key,
            associated_name=associated_name,
        )

    @classmethod
    def unserialize_v0(cls, client_id, data):
        """The first version unserialize, maintained for backwards compatibility."""
        assert isinstance(data, list)
        BINARY_SIZE = 4
        TEXT_SIZE = 1
        DATA_SIZE = BINARY_SIZE + TEXT_SIZE
        assert len(data) == DATA_SIZE, data
        binary_data = [binary_from_b64(d) for d in data[:BINARY_SIZE]]
        text_data = data[BINARY_SIZE:]
        public_key, private_key, id_key, server_public_key = binary_data
        (associated_name,) = text_data
        # public_key, private_key, server_public_key ignored, will be regenerated every time
        return cls(
            client_id=client_id,
            id_key=id_key,
            associated_name=associated_name,
        )