import numpy as np
import os
import tempfile
from fjcommon import functools_ext as ft
from fjcommon import timer
from fjcommon import no_op
import itertools
import arithmetic_coding as ac
import probclass


def encode_decode_to_file_ctx(syms, prediction_net: probclass.PredictionNetwork, syms_format='HWC', verbose=False):
    """
    Encode symbols with arithmetic coding to a temporary file on disk, decode
    them again, and check that the round trip reproduces the input.

    :param syms: HWC or CHW depending on syms_format, symbols of one image.
        Or BHWC, BCHW, in which case the number of bits needed for all
        batches is returned.
    :param prediction_net: network used for the probability model. This
        function uses its `input_ctx_shape`, `get_freqs`, `get_pr`,
        `pad_symbols_volume` and `undo_pad_symbols_volume` members.
        NOTE(review): the original description was truncated ("... arithmetic
        coding to be correct)"); presumably the network has to return
        identical frequencies while encoding and decoding -- confirm.
    :param syms_format: 'HWC' or 'CHW'; HWC input is transposed to CHW.
    :param verbose: if True, progress messages are printed with `print`.
    :return: number of bits to encode all symbols in `syms`
    """
    _print = print if verbose else no_op.NoOp()

    if len(syms.shape) == 4:
        # Batched input: process every image on its own and add up the bits.
        num_batches = syms.shape[0]
        return np.sum([encode_decode_to_file_ctx(syms[b, ...], prediction_net, syms_format, verbose)
                       for b in range(num_batches)])

    assert len(syms.shape) == 3, 'Expected HWC or CHW'
    assert syms_format in ('HWC', 'CHW')

    if syms_format == 'HWC':
        _print('Transposing symbols for encoding...')
        syms = np.transpose(syms, (2, 0, 1))

    # ---

    _print('Preparing encode...')
    foutid, fout_p = tempfile.mkstemp()
    # mkstemp leaves both the open descriptor and the file to the caller.
    # _encode closes the descriptor (it wraps it with open()); until it is
    # called the descriptor still has to be closed here on failure.
    fd_handed_over = False
    try:
        ctx_shape = prediction_net.input_ctx_shape
        get_freqs = ft.compose(ac.SimpleFrequencyTable, prediction_net.get_freqs)
        get_pr = prediction_net.get_pr

        # encode
        with timer.execute('Encoding time [s]'):
            _print('Encoding symbols of shape {} ({} symbols) with context shape {}...'.format(
                syms.shape, np.prod(syms.shape), ctx_shape))
            syms_padded = prediction_net.pad_symbols_volume(syms)
            fd_handed_over = True
            virtual_num_bits, first_sym, theoretical_bit_cost = _encode(
                foutid, syms_padded, ctx_shape, get_freqs, get_pr, _print)

        # The counted bits must be close to the summed -log2(p) cost.
        assert abs(virtual_num_bits - theoretical_bit_cost) < 50, 'Virtual: {} -- Theoretical: {}'.format(
            virtual_num_bits, theoretical_bit_cost)

        # bit count
        actual_num_bits = os.path.getsize(fout_p) * 8
        assert actual_num_bits == virtual_num_bits, '{} != {}'.format(
            actual_num_bits, virtual_num_bits)

        # decode
        with timer.execute('Decoding time [s]'):
            _print('Decoding symbols to shape {}, first_sym={}...'.format(
                syms_padded.shape, first_sym))
            syms_dec_padded = _decode(fout_p, syms_padded.shape, ctx_shape, first_sym, get_freqs, _print)
        syms_dec = prediction_net.undo_pad_symbols_volume(syms_dec_padded)

        # checkin' (takes no time)
        np.testing.assert_array_equal(syms, syms_dec)
        _print('Decoded symbols match input!')

        return actual_num_bits
    finally:
        # cleanup, also when an assertion or the coder failed, so that no
        # temporary file (or descriptor) is left behind
        if not fd_handed_over:
            os.close(foutid)
        os.remove(fout_p)


def _new_ctx_itr(syms, ctx_shape):
    """ Iterate over all contexts (blocks of shape `ctx_shape`) of `syms`. """
    return probclass.iter_over_blocks(syms, ctx_shape)


def _get_num_ctxs(syms_shape, ctx_shape):
    """ Number of contexts of shape `ctx_shape` in a volume of shape `syms_shape`. """
    return probclass.num_blocks(syms_shape, ctx_shape)


def _new_ctx_sym_itr(syms, ctx_shape):
    """
    Yield tuples (ctx, sym) for every context of `syms`, where sym is the
    symbol the context belongs to.
    :param syms: symbols volume, passed on to _new_ctx_itr
    :param ctx_shape: 3-tuple, (depth, height, width) of one context
    """
    assert len(ctx_shape) == 3
    _, h, w = ctx_shape
    for ctx in _new_ctx_itr(syms, ctx_shape):
        # symbol is in the last depth dimension in the center
        sym = ctx[-1, h // 2, w // 2]
        yield ctx, sym


def _new_sym_idxs_itr(syms_shape, ctx_size):
    """
    Iterate over the indices (d, h, w) of all symbols in a padded volume.
    :param syms_shape: shape (D, H, W) of the padded volume
    :param ctx_size: context size; ctx_size // 2 entries are skipped at both
        ends of H and W and at the start of D
    """
    D, H, W = syms_shape
    pad = ctx_size // 2
    # NOTE(review): the D range starts at `pad` but runs up to D, i.e. D
    # seems to be padded at the front only -- confirm against
    # pad_symbols_volume.
    return itertools.product(
        range(pad, D),  # D dimension is not padded
        range(pad, H - pad),
        range(pad, W - pad))  # yields tuples (d, h, w)


def _encode(foutid, syms, ctx_shape, get_freqs, get_pr, printer):
    """
    Arithmetic-encode `syms` to the file behind `foutid`.

    The first symbol is not written to the bit stream; it is returned to the
    caller, only its theoretical cost is accounted for.

    :param foutid: file descriptor to write to. It is opened with open() and
        therefore closed when this function is done.
    :param syms: CHW, padded
    :param ctx_shape: shape of one context
    :param get_freqs: function mapping a context to the frequency table given
        to the encoder
    :param get_pr: function mapping a context to something indexable by a
        symbol, used for the theoretical bit cost
    :param printer: print-like function used for progress output
    :return: tuple (number of bits counted by the output stream, first symbol,
        theoretical bit cost)
    """
    with open(foutid, 'wb') as fout:
        bit_out = ac.CountingBitOutputStream(
            bit_out=ac.BitOutputStream(fout))
        enc = ac.ArithmeticEncoder(bit_out)
        ctx_sym_itr = _new_ctx_sym_itr(syms, ctx_shape=ctx_shape)

        # First sym is stored separately using log2(L) bits or sth
        first_ctx, first_sym = next(ctx_sym_itr)
        first_pr = get_pr(first_ctx)
        first_bc = -np.log2(first_pr[first_sym])
        theoretical_bit_cost = first_bc

        num_ctxs = _get_num_ctxs(syms.shape, ctx_shape)

        # Encode other symbols
        for i, (ctx, sym) in enumerate(ctx_sym_itr):
            freqs = get_freqs(ctx)
            pr = get_pr(ctx)
            theoretical_bit_cost += -np.log2(pr[sym])
            enc.write(freqs, sym)
            if i % 1000 == 0:
                printer('\rFeeding context for symbol #{}/{}...'.format(i, num_ctxs), end='', flush=True)
        printer('\r\033[K', end='')  # clear line

        enc.finish()
        bit_out.close()
        return bit_out.num_bits, first_sym, theoretical_bit_cost


def _decode(fout_p, symbols_shape_padded, ctx_shape, first_sym, get_freqs, printer):
    """
    Decode the symbols stored in the file `fout_p`.

    :param fout_p: path of the file written by _encode
    :param symbols_shape_padded: shape of the padded symbols volume
    :param ctx_shape: shape of one context
    :param first_sym: first symbol, as returned by _encode
    :param get_freqs: function mapping a context to the frequency table given
        to the decoder
    :param printer: print-like function used for progress output
    :return: int32 array of shape `symbols_shape_padded` with the decoded
        symbols; entries that are never written stay zero
    """
    # Idea:
    # have a matrix symbols_decoded, initially all zeros.
    # put first_sym into symbols_decoded
    # use a normal ctx_itr to retrieve the current context from symbols_decoded
    # use symbol_idx_itr to get the index of the next decoded symbol
    # write the decoded symbol into symbols_decoded, then advance the ctx_itr to get the next context
    # NOTE(review): this assumes that the contexts yielded by ctx_itr reflect
    # the writes made to symbols_decoded in earlier iterations (lazy
    # iteration over views) -- confirm in probclass.iter_over_blocks.
    with open(fout_p, 'rb') as fin:
        bitin = ac.BitInputStream(fin)
        dec = ac.ArithmeticDecoder(bitin)

        symbols_decoded = np.zeros(symbols_shape_padded, dtype=np.int32)
        ctx_itr = _new_ctx_itr(symbols_decoded, ctx_shape)
        ctx_size = probclass.context_size_from_context_shape(ctx_shape)
        sym_idxs_itr = _new_sym_idxs_itr(symbols_shape_padded, ctx_size=ctx_size)

        next(ctx_itr)  # skip first ctx
        symbols_decoded[next(sym_idxs_itr)] = first_sym  # write first_sym

        num_ctxs = _get_num_ctxs(symbols_shape_padded, ctx_shape)
        for i, (current_ctx, next_decoded_sym_idx) in enumerate(zip(ctx_itr, sym_idxs_itr)):
            freqs = get_freqs(current_ctx)
            symbol = dec.read(freqs)
            symbols_decoded[next_decoded_sym_idx] = symbol
            if i % 1000 == 0:
                printer('\rFeeding context for symbol #{}/{}...'.format(i, num_ctxs), end='', flush=True)
        printer('\r\033[K', end='')  # clear line

        return symbols_decoded