#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : batch.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/18/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import collections
import collections.abc

import numpy as np

__all__ = ['batchify', 'unbatchify']


def batchify(inputs):
    """Merge a sequence of identically-structured samples into one batch.

    The structure of the first sample decides how the samples are merged:

    - tuple / list / ``UserList``: the result is a list with one batched
      entry per position of the first sample.
    - mapping / ``UserDict``: the result is a dict with one batched entry
      per key of the first sample.
    - anything else: the samples are combined with ``np.stack``.

    Args:
        inputs: a non-empty, indexable sequence of samples.

    Returns:
        A nested list/dict structure mirroring the first sample, whose
        leaves are the arrays returned by ``np.stack``.

    Raises:
        IndexError: if ``inputs`` is empty.
    """
    first = inputs[0]
    if isinstance(first, (tuple, list, collections.UserList)):
        return [batchify([ele[i] for ele in inputs]) for i in range(len(first))]
    # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    elif isinstance(first, (collections.abc.Mapping, collections.UserDict)):
        return {k: batchify([ele[k] for ele in inputs]) for k in first}
    return np.stack(inputs)


def unbatchify(inputs):
    """Split a batched structure back into a list of per-sample structures.

    This reverses the structural part of ``batchify``:

    - tuple / list / ``UserList``: every element is unbatched and the
      results are transposed, yielding one list per sample.
    - mapping / ``UserDict``: every value is unbatched, yielding one dict
      per sample with the same keys as ``inputs``.
    - anything else: the value is iterated with ``list(inputs)``.

    Args:
        inputs: a batched value, or a nested list/dict of batched values.

    Returns:
        A list with one entry per sample. An empty mapping gives ``[]``.
    """
    if isinstance(inputs, (tuple, list, collections.UserList)):
        outputs = [unbatchify(e) for e in inputs]
        # Transpose "per field -> per sample" into "per sample -> per field".
        return list(map(list, zip(*outputs)))
    elif isinstance(inputs, (collections.abc.Mapping, collections.UserDict)):
        outputs = {k: unbatchify(v) for k, v in inputs.items()}
        # The batch size is read from the first field. Indexing with
        # ``outputs[0]`` would look up the *key* 0 instead; the ``()``
        # default makes an empty mapping produce an empty list.
        batch_size = len(next(iter(outputs.values()), ()))
        return [{k: outputs[k][i] for k in inputs} for i in range(batch_size)]
    return list(inputs)