from abc import abstractmethod, ABC
from collections import OrderedDict
from itertools import zip_longest
from struct import Struct
import struct

from segpy import __version__
from segpy.datatypes import SEG_Y_TYPE_TO_CTYPE
from segpy.util import pairwise, intervals_partially_overlap, complementary_intervals, all_equal


def size_of(t):
    """Return the size recorded in the ``SIZE`` attribute of ``t``."""
    return t.SIZE


def compile_struct(header_format_class, start_offset=0, length_in_bytes=None, endian='>'):
    """Compile a struct description from a record.

    Args:
        header_format_class: A header_format class.

        start_offset: Optional start offset for the header in bytes.  Indicates
            the position of the start of the header in the same reference frame
            as which the field offsets are given.

        length_in_bytes: Optional length in bytes for the header. If the
            supplied header described a format shorter than this value the
            returned format will be padded with placeholders for bytes to be
            discarded. If the value is less than the minimum required for the
            format described by header_format_class an error will be raised.

        endian: '>' for big-endian data (the standard and default), '<' for
            little-endian (non-standard).

    Returns:
        A two-tuple containing in the zeroth element a format string which can
        be used with the struct.unpack function, and in the first element a
        list-of-lists of field names.  Each item in the outer list corresponds
        to an element of the tuple of data values returned by struct.unpack();
        each name associated with that index is a field to which the unpacked
        value should be assigned.

    Usage:

        format, allocations = compile_struct(TraceHeaderFormat)
        values = struct.unpack(format, buffer)
        field_names_to_values = {}
        for field_names, value in zip(allocations, values):
            for field_name in field_names:
                field_names_to_values[field_name] = value
        header = Header(**field_names_to_values)

    Raises:
        ValueError: If start_offset is negative.
        ValueError: If length_in_bytes is an integer less than one.
        TypeError: If header_format_class defines no fields.
        ValueError: If header_format_class contains fields which overlap but
            are not exactly coincident.
        TypeError: If header_format_class contains coincident fields of
            different types.
        ValueError: If header_format_class described a format longer than
            length_in_bytes.
    """
    if start_offset < 0:
        raise ValueError("start_offset {} is less than zero".format(start_offset))

    if isinstance(length_in_bytes, int) and length_in_bytes < 1:
        raise ValueError("length_in_bytes {} is less than one".format(length_in_bytes))

    fields = [getattr(header_format_class, name)
              for name in header_format_class.ordered_field_names()]

    sorted_fields = sorted(fields, key=lambda f: f.offset)

    if not sorted_fields:
        raise TypeError("Header format class {!r} defines no fields".format(header_format_class.__name__))

    if len(sorted_fields) > 1:
        for a, b in pairwise(sorted_fields):
            if intervals_partially_overlap(range(a.offset, a.offset + size_of(a.value_type)),
                                           range(b.offset, b.offset + size_of(b.value_type))):
                raise ValueError("Fields {!r} at offset {} and {!r} at offset {} of {} are distinct but overlap."
                                 .format(a.name, a.offset, b.name, b.offset, header_format_class.__name__))

    # The field with the greatest offset determines the minimum header length.
    last_field = sorted_fields[-1]
    defined_length = (last_field.offset - start_offset) + size_of(last_field.value_type)
    specified_length = defined_length if (length_in_bytes is None) else length_in_bytes
    padding_length = specified_length - defined_length
    if padding_length < 0:
        # The fields need more room than the caller allowed for.
        raise ValueError("Header length {!r} bytes defined by {!r} is greater than specified length in bytes {!r}"
                         .format(defined_length, header_format_class.__name__, specified_length))

    # Group fields sharing a start offset; each group maps to one unpacked value.
    offset_to_fields = OrderedDict()
    for field in sorted_fields:
        relative_offset = field.offset - start_offset  # relative_offset is zero-based
        coincident_fields = offset_to_fields.setdefault(relative_offset, [])
        if coincident_fields and coincident_fields[0].value_type is not field.value_type:
            raise TypeError("Coincident fields {!r} and {!r} at offset {} have different types {} and {}"
                            .format(coincident_fields[0],
                                    field,
                                    coincident_fields[0].offset,
                                    coincident_fields[0].value_type.__name__,
                                    field.value_type.__name__))
        coincident_fields.append(field)

    # Create a list of ranges where each range spans the byte indexes covered by each field
    field_spans = [range(offset, offset + size_of(coincident_fields[0].value_type))
                   for offset, coincident_fields in offset_to_fields.items()]

    # NOTE(review): presumably yields the gap preceding each field plus a
    # trailing gap up to specified_length, so there is at least one gap per
    # field; the loop below relies on that - confirm against
    # segpy.util.complementary_intervals.
    gap_intervals = complementary_intervals(field_spans, start=0, stop=specified_length)

    # Create a format string usable with the struct module
    format_chunks = [endian]
    representative_fields = (coincident_fields[0] for coincident_fields in offset_to_fields.values())
    for gap_interval, field in zip_longest(gap_intervals, representative_fields, fillvalue=None):
        gap_length = len(gap_interval)
        if gap_length > 0:
            # 'x' is a pad byte: skipped on unpack, written as zero on pack.
            format_chunks.append('x' * gap_length)
        if field is not None:
            format_chunks.append(SEG_Y_TYPE_TO_CTYPE[field.value_type.SEG_Y_TYPE])
    cformat = ''.join(format_chunks)

    # Create a list mapping item index to field names.
    # [0] -> ['field_1', 'field_2']
    # [1] -> ['field_3']
    # [2] -> ['field_4']
    field_name_allocations = [[field.name for field in coincident_fields]
                              for coincident_fields in offset_to_fields.values()]
    return cformat, field_name_allocations


def make_header_packer(header_format_class, endian='>'):
    """Create a HeaderPacker for a header format class.

    The start offset and length passed to compile_struct are read from the
    optional START_OFFSET_IN_BYTES and LENGTH_IN_BYTES attributes of
    header_format_class, defaulting to 0 and None respectively.

    Args:
        header_format_class: A header_format class.

        endian: '>' for big-endian data (the default), '<' for little-endian.

    Returns:
        A BijectiveHeaderPacker if every unpacked value is allocated to
        exactly one field name, otherwise a SurjectiveHeaderPacker.
    """
    cformat, field_name_allocations = compile_struct(
        header_format_class,
        getattr(header_format_class, 'START_OFFSET_IN_BYTES', 0),
        getattr(header_format_class, 'LENGTH_IN_BYTES', None),
        endian)
    structure = Struct(cformat)

    one_to_one = all(len(field_names) == 1 for field_names in field_name_allocations)
    if one_to_one:
        return BijectiveHeaderPacker(header_format_class, structure, field_name_allocations)
    return SurjectiveHeaderPacker(header_format_class, structure, field_name_allocations)


class HeaderPacker(ABC):
    """Packing and unpacking header instances."""

    def __init__(self, header_format_class, structure, field_name_allocations):
        self._header_format_class = header_format_class
        self._structure = structure
        self._field_name_allocations = field_name_allocations

    def __getstate__(self):
        # The Struct object is replaced by its format string in the pickled
        # state and rebuilt in __setstate__.
        state = self.__dict__.copy()
        state['__version__'] = __version__
        state['_structure_format'] = self._structure.format
        del state['_structure']
        return state

    def __setstate__(self, state):
        if state['__version__'] != __version__:
            raise TypeError("Cannot unpickle {} version {} into version {}"
                            .format(self.__class__.__name__,
                                    state['__version__'],
                                    __version__))
        del state['__version__']
        structure = Struct(state['_structure_format'])
        state['_structure'] = structure
        del state['_structure_format']
        self.__dict__.update(state)

    @property
    def header_format_class(self):
        """The header format class this packer was created for."""
        return self._header_format_class

    def pack(self, header):
        """Pack a header into a buffer.

        Args:
            header: An instance of this packer's header format class.

        Returns:
            The result of packing the header's field values with the
            compiled struct.

        Raises:
            TypeError: If header is not an instance of the header format
                class.
        """
        if not isinstance(header, self._header_format_class):
            raise TypeError("{}({}) cannot pack header of type {}.".format(
                self.__class__.__name__,
                self._header_format_class.__name__,
                header.__class__.__name__
            ))
        return self._pack(header)

    def unpack(self, buffer):
        """Unpack a buffer into a header object.

        Args:
            buffer: The serialised header data.

        Returns:
            The header object.

        Raises:
            ValueError: If the buffer cannot be unpacked with the compiled
                struct, for example because its length differs from the
                struct size.
        """
        try:
            values = self._structure.unpack(buffer)
        except struct.error as e:
            # Struct.unpack rejects buffers that are too long as well as too
            # short, so report both sizes and the underlying error text.
            raise ValueError("Buffer of length {} cannot be unpacked into a {}-byte header: {}"
                             .format(len(buffer), self._structure.size, str(e).capitalize())) from e
        else:
            return self._unpack(values)

    @abstractmethod
    def _pack(self, header):
        raise NotImplementedError

    @abstractmethod
    def _unpack(self, values):
        raise NotImplementedError

    def __repr__(self):
        return "{}({})".format(
            self.__class__.__name__,
            self._header_format_class.__name__)


class BijectiveHeaderPacker(HeaderPacker):
    """One-to-one packing/unpacking of serialised values to header fields."""

    def _pack(self, header):
        values = [getattr(header, names[0]) for names in self._field_name_allocations]
        return self._structure.pack(*values)

    def _unpack(self, values):
        # NOTE(review): values arrive in ascending field-offset order and are
        # passed positionally; presumably the header constructor's positional
        # order matches - confirm against the header format classes.
        return self._header_format_class(*values)


class SurjectiveHeaderPacker(HeaderPacker):
    """One-to-many unpacking of serialised values to header fields."""

    def _pack(self, header):
        values = []
        for names in self._field_name_allocations:
            field_values = [getattr(header, name) for name in names]
            # Coincident fields share one serialised slot, so they must agree.
            if not all_equal(field_values):
                raise ValueError("fields {} have unequal values {}"
                                 .format(', '.join(names), ', '.join(map(str, field_values))))
            values.append(field_values[0])
        return self._structure.pack(*values)

    def _unpack(self, values):
        # Every field name allocated to a slot receives that slot's value.
        kwargs = {name: value
                  for names, value in zip(self._field_name_allocations, values)
                  for name in names}
        return self._header_format_class(**kwargs)