"""
General-purpose helpers: filesystem access, dict merging, YAML I/O,
traceback coloring, plotting and threading utilities.
"""

import errno
import functools
import os
import re
import sys
import threading
import traceback

import ruamel.yaml as yaml
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from mindpark.utility.attrdict import AttrDict


def ensure_directory(directory):
    """
    Create the directory, including missing parents, if it does not exist.

    A leading `~` is expanded. An already existing path is not an error;
    any other `OSError` is propagated.
    """
    directory = os.path.expanduser(directory)
    try:
        os.makedirs(directory)
    except OSError as e:
        if e.errno != errno.EEXIST:
            # Bare raise keeps the original traceback intact.
            raise


def clamp(value, min_, max_):
    """
    Restrict the value to the closed interval [min_, max_].
    """
    return max(min_, min(value, max_))


def use_attrdicts(obj):
    """
    Recursively replace dicts by AttrDict objects.

    Lists are descended into; any other object, including tuples, is
    returned unchanged.
    """
    if isinstance(obj, dict):
        return AttrDict({k: use_attrdicts(v) for k, v in obj.items()})
    elif isinstance(obj, list):
        return [use_attrdicts(x) for x in obj]
    return obj


def merge_dicts(*mappings):
    """
    Recursively join mapping objects. Later values override earlier ones.

    Raises ValueError when the same key holds a dict in one mapping and a
    non-dict value in another.
    """
    merged = {}
    for mapping in mappings:
        for key, value in mapping.items():
            if key in merged:
                if isinstance(merged[key], dict) != isinstance(value, dict):
                    raise ValueError('Cannot merge dict with value.')
                if isinstance(merged[key], dict):
                    value = merge_dicts(merged[key], value)
            merged[key] = value
    return merged


def sum_dicts(*mappings):
    """
    Add up the values of several mappings key by key.

    A key that occurs in only one mapping keeps its value. Values of the
    same key are combined with `+` in the order the mappings were passed.
    """
    summed = {}
    # Iterate over each mapping; `mappings` itself is a tuple and has no
    # `items()` method.
    for mapping in mappings:
        for key, value in mapping.items():
            if key not in summed:
                summed[key] = value
            else:
                summed[key] = summed[key] + value
    return summed


def lazy_property(function):
    """
    Decorator that turns a method into a property computed only once.

    The result of the first access is stored on the instance under the
    method name prefixed with an underscore and reused afterwards.
    """
    attribute = '_' + function.__name__

    @property
    @functools.wraps(function)
    def wrapper(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return wrapper


def get_subdirs(directory):
    """
    Return the sorted paths of all directories directly inside directory.
    """
    subdirs = os.listdir(directory)
    subdirs = [os.path.join(directory, x) for x in subdirs]
    subdirs = [x for x in subdirs if os.path.isdir(x)]
    return sorted(subdirs)


def color_stack_trace():
    """
    Install an exception hook that prints syntax-highlighted tracebacks.

    Highlighting uses pygments. If anything goes wrong while colorizing,
    for example because pygments cannot be imported, the plain traceback
    is written instead. The hook is also installed for threads via
    `setup_thread_excepthook()`.
    """
    def excepthook(type_, value, trace):
        text = ''.join(traceback.format_exception(type_, value, trace))
        try:
            from pygments import highlight
            from pygments.lexers import get_lexer_by_name
            from pygments.formatters import TerminalFormatter
            lexer = get_lexer_by_name('pytb', stripall=True)
            formatter = TerminalFormatter()
            sys.stderr.write(highlight(text, lexer, formatter))
        except Exception:
            # Best effort: never lose the traceback because of coloring.
            sys.stderr.write(text)
            sys.stderr.write('Failed to colorize the traceback.')
    sys.excepthook = excepthook
    setup_thread_excepthook()


def setup_thread_excepthook():
    """
    Workaround for `sys.excepthook` thread bug from:
    http://bugs.python.org/issue1230540

    Call once from the main thread before creating any threads.
    """
    init_original = threading.Thread.__init__

    def init(self, *args, **kwargs):
        init_original(self, *args, **kwargs)
        run_original = self.run

        def run_with_except_hook(*args2, **kwargs2):
            try:
                run_original(*args2, **kwargs2)
            except Exception:
                # Route the error through the globally installed hook.
                sys.excepthook(*sys.exc_info())

        # Shadow the bound method on the instance so that the thread
        # executes the wrapped version.
        self.run = run_with_except_hook

    threading.Thread.__init__ = init


def natural_sorted(collection, key=lambda x: x):
    """
    Sort strings so that embedded numbers are ordered by their value.

    The key function must return a string for each element. Runs of the
    digits 0-9 are compared as integers, everything else is compared
    case-insensitively.
    """
    def natural_key(item):
        # Splitting on a capture group alternates text and digit runs, so
        # odd positions always hold the digits. Deciding by position
        # rather than `str.isdigit()` keeps every position consistently
        # typed, even for text containing non-ASCII digit characters.
        parts = re.split('([0-9]+)', key(item))
        return [
            int(part) if index % 2 else part.lower()
            for index, part in enumerate(parts)]
    return sorted(collection, key=natural_key)


def flatten(collection):
    """
    Flatten arbitrarily nested lists into a single new list.

    Only `list` instances are descended into; other elements, including
    tuples, are kept as they are. The order of the elements is preserved.
    """
    flat = []
    # An explicit stack of iterators replaces recursion, so neither the
    # length nor the nesting depth of the input can hit the recursion
    # limit.
    stack = [iter(collection)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, list):
                stack.append(iter(item))
                break
            flat.append(item)
        else:
            # The innermost iterator is exhausted.
            stack.pop()
    return flat


def dump_yaml(data, *path):
    """
    Write data as block-style YAML to the file at the joined path.

    Dict entries whose key is a string starting with an underscore are
    skipped and type objects are written as their name. Missing parent
    directories are created.
    """
    def convert(obj):
        if isinstance(obj, dict):
            # Only string keys can be private; other key types are kept.
            return {
                convert(k): convert(v) for k, v in obj.items()
                if not (isinstance(k, str) and k.startswith('_'))}
        if isinstance(obj, list):
            return [convert(x) for x in obj]
        if isinstance(obj, type):
            return obj.__name__
        return obj
    filename = os.path.join(*path)
    directory = os.path.dirname(filename)
    # A bare filename has no directory part and `os.makedirs('')` fails.
    if directory:
        ensure_directory(directory)
    with open(filename, 'w') as file_:
        yaml.safe_dump(convert(data), file_, default_flow_style=False)


def print_headline(*message, style='-', minwidth=40):
    """
    Print the space-joined message between two horizontal rules.

    The rules repeat `style` once per character of the message, but at
    least `minwidth` times. Standard output is flushed afterwards.
    """
    # Convert explicitly so that non-string parts like numbers work.
    message = ' '.join(str(part) for part in message)
    width = max(minwidth, len(message))
    print('\n' + style * width)
    print(message)
    print(style * width + '\n', flush=True)


def read_yaml(*path):
    """
    Load the YAML file at the joined path, with dicts as AttrDict objects.
    """
    path = os.path.join(*path)
    with open(path) as file_:
        # NOTE(security): `yaml.load` without an explicit loader can
        # construct arbitrary Python objects. Only use this on trusted
        # files, or consider switching to `yaml.safe_load` after checking
        # that no existing file relies on Python-specific tags.
        return use_attrdicts(yaml.load(file_))


def aggregate(values, borders, reducer):
    """
    Reduce consecutive slices of values and return the results as array.

    The slices are `values[borders[i]: borders[i + 1]]` for each pair of
    neighboring borders, and `reducer` is applied to each of them.
    """
    groups = [
        reducer(values[start: stop])
        for start, stop in zip(borders[:-1], borders[1:])]
    return np.array(groups)


def add_color_bar(ax, img):
    """
    Attach a color bar for img to the right side of ax and return it.
    """
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='7%', pad=0.1)
    bar = plt.colorbar(img, cax=cax)
    return bar


# Serializes the lazy creation of the locks used by `synchronized`.
_LOCK_CREATION_GUARD = threading.Lock()


def synchronized(function):
    """
    Decorator for methods to restrict the parallel access per instance.
    For classmethods, the restriction is per class, making it global. The
    decorator should not be used for staticmethods.

    The lock is stored on the instance as `_<method name>_lock` and is not
    reentrant, so a synchronized method must not call itself on the same
    object. The return value of the method is passed through.
    """
    lock_name = '_{}_lock'.format(function.__name__)

    @functools.wraps(function)
    def decorator(self, *args, **kwargs):
        if not hasattr(self, lock_name):
            # Check again under the guard, so that two threads entering
            # at the same time cannot install two different locks.
            with _LOCK_CREATION_GUARD:
                if not hasattr(self, lock_name):
                    setattr(self, lock_name, threading.Lock())
        with getattr(self, lock_name):
            return function(self, *args, **kwargs)
    return decorator


class OptionalContext:
    """
    Context manager that delegates to another one only if it is truthy.

    With a falsy context such as None, entering and leaving the block
    does nothing and `as` binds None.
    """

    def __init__(self, context):
        self._context = context

    def __enter__(self, *args, **kwargs):
        if self._context:
            # Hand the wrapped value through for `with ... as value`.
            return self._context.__enter__(*args, **kwargs)
        return None

    def __exit__(self, *args, **kwargs):
        if self._context:
            # Preserve the decision of the wrapped context on whether the
            # exception should be suppressed.
            return self._context.__exit__(*args, **kwargs)
        return None