# This Source Code Form is subject to the terms of the MIT License.
# If a copy of the MIT License was not distributed with this
# file, you can obtain one at https://opensource.org/licenses/MIT.
#
"""Access to the HOTP data for server client pairs.

Represents the data for a particular HOTP server and client pair.
Stores the client data in a file.
Loads client data from a file.
Encrypts data using the supplied passphrase.

This implementation requires:
* Python 3.5 or later
* cryptography 1.3 or later (see https://cryptography.io/en/latest/)
* iso8601 (see https://pypi.python.org/pypi/iso8601); only imported when a
  lastCountUpdateTime string is not in this module's own timestamp format
* python-dateutil 2.1 or later
  (see https://pypi.python.org/pypi/python-dateutil/2.1)
* six 1.10 or later (https://pypi.python.org/pypi/six/1.10.0)
"""

import json


class DecryptionError(Exception):
    """Failed to decrypt the data."""

    pass


class FileCorruptionError(Exception):
    """HOTP data file is corrupted, unreadable."""

    pass


class ClientDataDecoder(json.JSONDecoder):
    """A JSONDecoder that recognizes ClientData objects in a JSON string."""

    def __init__(self, **kw_args):
        """Initialise the JSONDecoder base with a custom object_hook.

        The custom object_hook recognizes a dictionary that represents a
        ClientData object and decodes it as a ClientData object. All other
        objects are handled by the standard JSONDecoder machinery.

        Args:
            Same arguments as JSONDecoder.__init__() with the exception that
            'strict' is always set to False. If an 'object_hook' is supplied
            then it will be called by _object_decode() if the object is not
            interpreted as ClientData.
        """
        self._other_object_hook = kw_args.get('object_hook')
        kw_args_new = kw_args.copy()
        kw_args_new['object_hook'] = self._object_decode
        # Note: strict=False because the notes attribute might contain
        #       line feeds.
        kw_args_new['strict'] = False
        # Initialise the base class so that every inherited JSONDecoder
        # method (decode, raw_decode) uses the configured hook.
        super().__init__(**kw_args_new)

    def _object_decode(self, d):
        """Convert decoded JSON to a ClientData object.

        Take the object decoded from the JSON and if it corresponds to
        ClientData objects, convert it to a ClientData object.

        Returns:
            If converted, returns a ClientData object. Otherwise, returns the
            result of the caller-supplied object_hook if there is one, or the
            original object d.
        """
        if isinstance(d, dict) and 'clientId' in d:
            return ClientData(**d)
        if self._other_object_hook is not None:
            return self._other_object_hook(d)
        return d

    def decode(self, s, *args, **kwargs):
        """Invoke the standard JSONDecoder decode() with the custom hook.

        Returns:
            The Python representation of 's'.
        """
        return super().decode(s, *args, **kwargs)


class ClientDataEncoder(json.JSONEncoder):
    """A specialized JSONEncoder that handles ClientData objects.

    Specialize the standard JSONEncoder class to detect ClientData objects
    and convert them to a standard object type that the JSONEncoder can
    handle.
    """

    def default(self, o):
        """Detect and convert ClientData objects.

        Detect ClientData objects and convert them to dictionaries. If not
        ClientData then invoke the superclass default() method (which raises
        TypeError).
        """
        if isinstance(o, ClientData):
            return o.to_dict()
        return super().default(o)


class ClientData:
    """Represents a HOTP configuration from the client point of view."""

    # -------------------------------------------------------------------------+
    # class attributes
    # -------------------------------------------------------------------------+

    # Process-wide caches for tz() and utz(); filled in on first use.
    __tz = None
    __utz = None

    # -------------------------------------------------------------------------+
    # static methods
    # -------------------------------------------------------------------------+

    @staticmethod
    def utz():
        """UTC time zone."""
        from datetime import timezone
        if ClientData.__utz is None:
            ClientData.__utz = timezone.utc
        return ClientData.__utz

    @staticmethod
    def tz():
        """Local time zone, as a fixed UTC offset.

        The offset is measured on the first call and cached for the rest of
        the process, so a later daylight-saving change is not reflected.
        """
        from datetime import datetime, timezone
        if ClientData.__tz is None:
            local_now = datetime.now()
            utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
            ClientData.__tz = timezone(
                ClientData._utc_offset_from_readings(local_now, utc_now))
        return ClientData.__tz

    @staticmethod
    def _utc_offset_from_readings(local_now, utc_now):
        """Return the local UTC offset implied by two naive clock readings.

        Args:
            local_now: naive datetime holding the local wall-clock time.
            utc_now: naive datetime holding the UTC time, read at
                (almost) the same instant.

        Returns:
            A timedelta of whole minutes; positive east of UTC (local time
            ahead of UTC), negative west of UTC. Rounding to the nearest
            minute absorbs the tiny delay between the two readings.
        """
        from datetime import timedelta
        seconds = (local_now - utc_now).total_seconds()
        return timedelta(minutes=int(round(seconds / 60.0)))

    @staticmethod
    def _format_timestamp(t):
        """Format a datetime as "%Y%m%dT%H%M%S%z" with a 4-digit year.

        strftime's %Y does not zero-pad years below 1000 on some systems
        (e.g. Debian), so the year is formatted explicitly. %z is empty for
        a naive datetime.
        """
        return "{0:04d}{1}".format(t.year, t.strftime("%m%dT%H%M%S%z"))

    @staticmethod
    def _parse_timestamp(s):
        """Parse a timestamp string into a datetime.

        Strings in this module's own "%Y%m%dT%H%M%S%z" format are parsed with
        the standard library; any other spelling is delegated to
        iso8601.parse_date (imported only when needed).
        """
        from datetime import datetime
        try:
            return datetime.strptime(s, "%Y%m%dT%H%M%S%z")
        except ValueError:
            import iso8601
            return iso8601.parse_date(s)

    # ------------------------------------------------------------------------+
    # internal methods
    # ------------------------------------------------------------------------+

    def _init_client_id(self, kw_args):
        """Process kw_arg clientId (required, non-empty str)."""
        if 'clientId' not in kw_args:
            raise ValueError("Need a clientId string.")
        self.__client_id = kw_args['clientId']
        if not isinstance(self.__client_id, str):
            raise TypeError("clientId must be a string.")
        if 0 == len(self.__client_id):
            raise ValueError("clientId must be a non-empty string.")

    def _init_shared_secret(self, kw_args):
        """Process kw_arg sharedSecret (required, non-empty str or bytes)."""
        if 'sharedSecret' not in kw_args:
            raise ValueError("Need a sharedSecret string.")
        self.__shared_secret = kw_args['sharedSecret']
        if not isinstance(self.__shared_secret, (str, bytes)):
            raise TypeError(
                "sharedSecret must be a string or byte string.")
        if 0 == len(self.__shared_secret):
            raise ValueError(
                "sharedSecret must be a non-empty string or byte string.")

    def _init_counter_from_time(self, kw_args):
        """Process kw_arg counterFromTime (default True; truthiness used)."""
        self.__counter_from_time = bool(kw_args.get('counterFromTime', True))

    def _init_last_count(self, kw_args):
        """Process kw_arg lastCount (default 0; must not be negative)."""
        self.__last_count = 0
        if 'lastCount' in kw_args:
            self.__last_count = int(kw_args['lastCount'])
            if 0 > self.__last_count:
                raise ValueError(
                    "lastCount must be zero or a positive integer")

    def _init_last_count_update_time(self, kw_args):
        """Process kw_arg lastCountUpdateTime.

        Accepts a datetime or a timestamp string; a naive value is taken to be
        UTC. The value is stored as a "%Y%m%dT%H%M%S%z" string.
        """
        from datetime import datetime
        if 'lastCountUpdateTime' not in kw_args:
            # Default: 0001-01-01T00:00:00 UTC.
            self.__last_count_update_time = self._format_timestamp(
                datetime(1, 1, 1, 0, 0, 0, 0, ClientData.utz()))
            return
        v = kw_args['lastCountUpdateTime']
        if isinstance(v, datetime):
            t = v
        elif isinstance(v, str):
            t = self._parse_timestamp(v)
        else:
            raise TypeError(
                "lastCountUpdateTime must be datetime object"
                " or a datetime string")
        if t.tzinfo is None:
            t = t.replace(tzinfo=ClientData.utz())
        self.__last_count_update_time = self._format_timestamp(t)

    def _init_period(self, kw_args):
        """Process kw_arg period (default 30; must be positive)."""
        self.__period = 30
        if 'period' in kw_args:
            p = int(kw_args['period'])
            if 0 >= p:
                raise ValueError("period must be a positive integer")
            self.__period = p

    def _init_password_length(self, kw_args):
        """Process kw_arg passwordLength (default 6; range [1,10])."""
        self.__password_length = 6
        if 'passwordLength' in kw_args:
            pwd_len = int(kw_args['passwordLength'])
            if not 1 <= pwd_len <= 10:
                raise ValueError("passwordLength must be in the range [1,10]")
            self.__password_length = pwd_len

    def _init_tags(self, kw_args):
        """Process kw_arg tags (a str, or a list/tuple of str).

        Empty tag strings are silently dropped.
        """
        self.__tags = []
        if 'tags' not in kw_args:
            return
        tags = kw_args['tags']
        if isinstance(tags, (tuple, list)):
            for tag in tags:
                if not isinstance(tag, str):
                    raise TypeError(
                        "tags must be a sequence of string values")
                if tag:
                    self.__tags.append(tag)
        elif isinstance(tags, str):
            if tags:
                self.__tags.append(tags)
        else:
            raise TypeError("tags must be a sequence of string values")

    def _init_note(self, kw_args):
        """Process kw_arg note (default empty string)."""
        self.__note = ""
        if 'note' in kw_args:
            if not isinstance(kw_args['note'], str):
                raise TypeError("note must be a string")
            self.__note = kw_args['note']

    # ------------------------------------------------------------------------+
    # dunder methods
    # ------------------------------------------------------------------------+

    def __init__(self, **kw_args):
        """Constructor for ClientData object.

        Used by ClientDataDecoder, a JSONDecoder. Unknown keyword arguments
        are ignored.

        Args:
            clientId: Required. A string to identify the client and server
                combination that this ClientData object represents. For
                example, an Amazon Web Services account 12345654321 and user
                what.me.worry might be identified by a client id of
                "12345654321@what.me.worry".
            sharedSecret: Required. The shared secret provided by the server
                when the HOTP configuration was created for the client. This
                is either a Base32 encoded string representing a byte string,
                or it is the byte string itself.
            counterFromTime: Default True. Whether to use a time-based HOTP
                (True) or a counter-based HOTP (False).
            lastCount: Default 0. The counter used in the most recent
                counter-based HOTP calculation. Must be zero or a positive
                integer.
            lastCountUpdateTime: Default 0001-01-01T00:00:00 UTC. The time
                that the last_count was most recently changed, as a datetime
                or a timestamp string; naive values are taken to be UTC.
            period: Default 30. The number of seconds in the period for which
                a time-based HOTP is calculated. Must be a positive integer.
            passwordLength: Default is 6. The number of digits in the HOTP
                string. Must be in the range [1,10]
            tags: Default is an empty list. This is a list of strings that
                can be used to filter a collection of ClientData objects.
            note: Default is an empty string. This is just a freeform text
                field in which any notes about the client server HOTP
                combination can be supplied.

        Raises:
            ValueError: a required argument is missing or a value is out of
                range.
            TypeError: an argument has the wrong type.
        """
        # Format of the stored timestamp strings (see _format_timestamp).
        self._isoFmt = "%Y%m%dT%H%M%S%z"
        self._init_client_id(kw_args)
        self._init_shared_secret(kw_args)
        self._init_counter_from_time(kw_args)
        self._init_last_count(kw_args)
        self._init_last_count_update_time(kw_args)
        self._init_period(kw_args)
        self._init_password_length(kw_args)
        self._init_tags(kw_args)
        self._init_note(kw_args)

    def __str__(self):
        """Stringify this object.

        NOTE(review): the output includes the shared secret in clear text;
        avoid logging it.
        """
        return "\n".join([
            "client_id: '{0}'".format(self.__client_id),
            "shared_secret: '{0}'".format(self.__shared_secret),
            "counter_from_time: {0}".format(self.__counter_from_time),
            "last_count: {0}".format(self.__last_count),
            "last_count_update_time: {0}".format(
                self.__last_count_update_time),
            "period: {0}".format(self.__period),
            "password_length: {0}".format(self.__password_length),
            "tags: {0}".format(self.__tags),
            "note: \"\"\"{0}\"\"\"".format(self.__note),
        ])

    def __eq__(self, other):
        """Whether this object is equal to the other.

        Two ClientData objects are equal when they have exactly the same type
        and identical instance attributes.
        """
        if type(self) != type(other):
            return False
        return vars(self) == vars(other)

    def __ne__(self, other):
        """Whether this object is not equal to the other."""
        return not self.__eq__(other)

    def __repr__(self):
        """Canonical string representation of this object."""
        return ' '.join([
            "ClientData(client_id='{0}',".format(self.__client_id),
            "shared_secret='{0}',".format(self.__shared_secret),
            "last_count_update_time='{0}')".format(
                self.__last_count_update_time),
        ])

    # -------------------------------------------------------------------------+
    # properties
    # -------------------------------------------------------------------------+

    def client_id(self):
        """Get the string that identifies the client and server combination.

        Returns:
            The client_id as a string.
        """
        return self.__client_id

    def shared_secret(self):
        """Get the shared secret used to calculate the HOTP.

        Returns:
            The shared_secret exactly as supplied: a (Base32) string or a
            byte string.
        """
        return self.__shared_secret

    def counter_from_time(self):
        """Whether HOTP is determined from current time.

        Get whether the HOTP is calculated with a counter determined from
        the current time.

        Returns:
            True or False.
        """
        return self.__counter_from_time

    def incremented_count(self):
        """Increment the counter and return the new value.

        Will update last_count() and last_count_update_time() properties
        (the latter to the current local time). Only relevant if
        counter_from_time() is False.

        Returns:
            The incremented last_count value.
        """
        from datetime import datetime
        self.__last_count += 1
        # Get the local time, with timezone.
        self.set_last_count_update_time(datetime.now(ClientData.tz()))
        return self.last_count()

    def last_count(self):
        """Get the counter value from last counter-based HOTP calculation.

        Only relevant if counter_from_time() is False.

        Returns:
            The last_count integer value.
        """
        return self.__last_count

    def last_count_update_time(self):
        """Get the timestamp of the last counter-based HOTP calculation.

        Only relevant if counter_from_time() is False.

        Returns:
            The timestamp as a string, normally in "%Y%m%dT%H%M%S%z" format
            (a string passed to set_last_count_update_time() is stored
            verbatim).
        """
        return self.__last_count_update_time

    def set_last_count_update_time(self, update_time):
        """Set the timestamp of the last counter-based HOTP calculation.

        Only relevant if counter_from_time() is False.

        Args:
            update_time: either a datetime object (preferably with a
                timezone; a naive value is stored without an offset), or a
                string with time in ISO format "%Y%m%dT%H%M%S%z", which is
                stored as-is.
        """
        from datetime import datetime
        if isinstance(update_time, datetime):
            self.__last_count_update_time = self._format_timestamp(
                update_time)
        else:
            self.__last_count_update_time = update_time

    def period(self):
        """The period of the time-based counter used in the HOTP calculation.

        Only relevant if counter_from_time() is True.
        """
        return self.__period

    def password_length(self):
        """The length of the HOTP code to be generated (e.g. 6)."""
        return self.__password_length

    def tags(self):
        """List of tag strings (a copy)."""
        return self.__tags[:]

    def note(self):
        """Freeform note text."""
        return self.__note

    # -------------------------------------------------------------------------+
    # serialization (for JSON)
    # -------------------------------------------------------------------------+

    def to_dict(self):
        """Represent object as a key-value collection.

        Used by ClientDataEncoder, a JSONEncoder. The keys match the
        constructor's keyword arguments, so ClientData(**cd.to_dict())
        round-trips.
        """
        return {
            'clientId': self.__client_id,
            'sharedSecret': self.__shared_secret,
            'counterFromTime': self.__counter_from_time,
            'lastCount': self.__last_count,
            'lastCountUpdateTime': self.__last_count_update_time,
            'period': self.__period,
            'passwordLength': self.__password_length,
            # Copy so callers cannot mutate the internal list.
            'tags': self.__tags[:],
            'note': self.__note,
        }


class ClientFile:
    """Support persistence of ClientData objects in encrypted file.

    Encapsulates the work needed to persist a collection of ClientData
    objects using an encrypted file.

    File layout, as written by save():
    * 16 cleartext header bytes: four big-endian unsigned 32-bit integers
      (magic number, file version, key-stretch count, magic number);
    * AES-256-CBC ciphertext of the same 16 header bytes followed by the
      UTF-8 JSON document, padded to a multiple of 16 bytes.

    NOTE(review): the IV is derived from the key, so every file written with
    the same passphrase uses the same IV; there is also no salt and no MAC.
    Changing either would change the file format.
    """

    def __init__(self, passphrase):
        """Create a ClientFile object.

        Derives the key (slow: 256 * 1024 hash iterations) and the IV.

        Args:
            passphrase: The passphrase string used to encrypt and decrypt the
                data file.
        """
        self.__key_stretches = 256 * 1024
        self.__magic_number = 0x7A6A5A4A
        self.__file_version = 1
        self.__key = self._produce_key(passphrase)
        self.__iv = self._produce_iv(self.__key)

    # -------------------------------------------------------------------------+
    # internal properties
    # -------------------------------------------------------------------------+

    def _get_key_stretches(self):
        """Get count of hash iterations used to slow key generation.

        Get the number of hash iterations used to stretch key generation
        (to defeat brute force key cracking).
        """
        return self.__key_stretches

    # -------------------------------------------------------------------------+
    # internal methods
    # -------------------------------------------------------------------------+

    @staticmethod
    def _strip_padding(data):
        """Remove the padding appended by _encrypt().

        _encrypt() appends N copies of the byte value N, with N in [1,16].

        Raises:
            FileCorruptionError: when the padding is missing or malformed.
        """
        pad_length = data[-1] if data else 0
        if (not 1 <= pad_length <= 16
                or data[-pad_length:] != bytes([pad_length]) * pad_length):
            raise FileCorruptionError()
        return data[:-pad_length]

    def _decrypt(self, b, strip_padding=True):
        """Decrypt a byte string.

        Uses the AES 256-bit symmetric key cypher in CBC mode.

        Args:
            b: the byte string to decrypt; its length must be a multiple of
                16 bytes.
            strip_padding: whether to remove the padding (padding is required
                by AES to make the encrypted data an exact multiple of 16
                bytes in length).

        Returns:
            The decrypted data as a byte string.

        Raises:
            FileCorruptionError: strip_padding is True and the padding is
                malformed.
        """
        from cryptography.hazmat.primitives.ciphers \
            import Cipher, algorithms, modes
        from cryptography.hazmat.backends import default_backend
        cypher = Cipher(
            algorithms.AES(self.__key), modes.CBC(self.__iv),
            backend=default_backend())
        decryptor = cypher.decryptor()
        result = decryptor.update(b) + decryptor.finalize()
        if strip_padding:
            result = self._strip_padding(result)
        return result

    def _encrypt(self, b):
        """Encrypt a byte string.

        Uses the AES 256-bit symmetric key cypher in CBC mode. Appends 1 to
        16 padding bytes, each holding the padding length, so the data is an
        exact multiple of 16 bytes.

        Args:
            b: the byte string to encrypt.

        Returns:
            The encrypted data as a byte string.
        """
        from cryptography.hazmat.primitives.ciphers \
            import Cipher, algorithms, modes
        from cryptography.hazmat.backends import default_backend
        cypher = Cipher(
            algorithms.AES(self.__key), modes.CBC(self.__iv),
            backend=default_backend())
        encryptor = cypher.encryptor()
        pad_length = 16 - (len(b) % 16)
        padded = b + bytes([pad_length]) * pad_length
        return encryptor.update(padded) + encryptor.finalize()

    def _produce_key(self, passphrase):
        """Generate encrypt key.

        Starts a SHA-256 state with the UTF-8 passphrase. Then, for each key
        stretch, appends the current digest followed by the passphrase to
        that same running state. The final digest is the key.

        NOTE(review): a standard KDF (e.g. PBKDF2 with a salt) would be
        stronger, but switching would change the file format.

        Args:
            passphrase: the passphrase string.

        Returns:
            The encryption key as a byte string 32 bytes in length.
        """
        from hashlib import sha256
        pp = bytes(passphrase, 'utf-8')
        hash_alg = sha256(pp)
        for _ in range(self._get_key_stretches()):
            hash_alg.update(hash_alg.digest() + pp)
        return hash_alg.digest()

    def _produce_iv(self, key):
        """Generate initialization vector.

        Generates the initialization vector for use by the symmetric key
        encryption algorithm used to encrypt and decrypt the data file: the
        16 key bytes starting at offset (key[31] & 0x0F), i.e. 0..15.

        Args:
            key: the 32-byte encryption key

        Returns:
            The initialization vector as a byte string 16 bytes in length.
        """
        start = key[31] & 0x0F
        return key[start:start + 16]

    def _validate_header(self, cleartext_header, decrypted_header):
        """Whether header of data file is OK and matches expected values.

        Checks that the decrypted header values match expectation, and that
        the cleartext header is identical to the decrypted header.

        Args:
            cleartext_header: A bytes object containing the first 16 bytes
                of the data file.
            decrypted_header: A bytes object containing the next 16 bytes
                of the data file, decrypted.

        Raises:
            DecryptionError: When the decrypted header doesn't match expected
                values, so the passphrase is probably incorrect, or the key
                stretch count (of hash iterations) is incorrect, or both.
            FileCorruptionError: When either header is not exactly 16 bytes,
                or the cleartext header differs from the decrypted one.
        """
        import struct
        if len(cleartext_header) != 16 or len(decrypted_header) != 16:
            raise FileCorruptionError()
        magic_number1, _file_version, _key_stretches, magic_number2 = \
            struct.unpack("!IIII", decrypted_header)
        if (self.__magic_number != magic_number1
                or self.__magic_number != magic_number2):
            raise DecryptionError()
        if cleartext_header != decrypted_header:
            raise FileCorruptionError()

    # -------------------------------------------------------------------------+
    # public methods
    # -------------------------------------------------------------------------+

    def load(self, filepath):
        """Load ClientData objects from encrypted file.

        Load the list of ClientData objects from a file encrypted with the
        passphrase.

        Args:
            filepath: the fully qualified path to the data file.

        Returns:
            The list of ClientData objects found in the data file (an empty
            list if the document is JSON null).

        Raises:
            DecryptionError: the passphrase (or key stretch count) is wrong.
            FileCorruptionError: the file is truncated or damaged.
        """
        with open(filepath, 'rb') as f:
            header = f.read(16)
            cypher_text = f.read()
        # CBC ciphertext is always a non-empty multiple of the block size.
        if not cypher_text or len(cypher_text) % 16:
            raise FileCorruptionError()
        # Check the header before trusting the padding, so a wrong passphrase
        # is always reported as DecryptionError.
        data = self._decrypt(cypher_text, strip_padding=False)
        self._validate_header(header, data[:16])
        plain_text = str(self._strip_padding(data)[16:], 'utf-8')
        cds = json.loads(plain_text, cls=ClientDataDecoder)
        if cds is None:
            cds = []
        return cds

    def save(self, filepath, client_data_list, new_passphrase=None):
        """Store ClientData objects into encrypted file.

        Store the list of ClientData objects as a JSON document in a file
        encrypted with the passphrase.

        Args:
            filepath: the fully qualified path to the data file.
            client_data_list: a list of ClientData objects to store in the
                data file.
            new_passphrase: if given, this ClientFile switches to the key
                derived from it before encrypting, and keeps using it
                afterwards.
        """
        import struct
        plain_text = json.dumps(
            client_data_list, sort_keys=True, indent=4,
            separators=(',', ': '), cls=ClientDataEncoder)
        header = struct.pack(
            "!IIII", self.__magic_number, self.__file_version,
            self.__key_stretches, self.__magic_number)
        data = header + bytes(plain_text, 'utf-8')
        if new_passphrase is not None:
            self.__key = self._produce_key(new_passphrase)
            self.__iv = self._produce_iv(self.__key)
        cypher_text = self._encrypt(data)
        with open(filepath, 'wb') as f:
            f.write(header)
            f.write(cypher_text)

    def validate(self, filepath):
        """Decrypt the data file header, and validate the file is readable.

        Decrypt the initial part of the file and validate it to ensure the
        passphrase is correct.

        Args:
            filepath: the fully qualified path to the data file.

        Returns:
            True if the decrypted data in the leading section of the file
            matches expected values; otherwise, returns False (meaning the
            passphrase is invalid, or the number of key stretching hashes is
            invalid, or both are invalid).

        Raises:
            FileCorruptionError: the file is too short to hold a header, or
                its cleartext and decrypted headers differ.
        """
        with open(filepath, 'rb') as f:
            header_bytes = f.read(16)
            cypher_bytes = f.read(16)
        if len(cypher_bytes) != 16:
            raise FileCorruptionError()
        # The first CBC block decrypts independently of the rest of the file.
        data = self._decrypt(cypher_bytes, strip_padding=False)
        try:
            self._validate_header(header_bytes, data)
        except DecryptionError:
            return False
        return True